
BGE Fine-tuning for Chinese Book Retrieval

A comprehensive, production-ready implementation for fine-tuning the BAAI BGE-M3 embedding model and the BGE-Reranker cross-encoder for optimal retrieval performance on Chinese book content, with English query support.


Requirements & Compatibility

  • Python >= 3.10
  • CUDA >= 12.8 (for NVIDIA RTX 5090 and similar)
  • PyTorch >= 2.2.0 (Use nightly build for CUDA 12.8+ support)
  • See requirements.txt for all dependencies
  • Ascend NPU support: Install torch_npu and set [hardware]/[ascend] config as needed

Hardware Support

  • NVIDIA GPUs: RTX 5090, A100, H100 (CUDA 12.8+)
  • Huawei Ascend NPUs: Ascend 910, 910B (with torch_npu)
  • CPU: Fallback support for development/testing

🚀 Features

Core Capabilities

  • State-of-the-art Techniques: InfoNCE loss with proper temperature scaling, self-knowledge distillation, ANCE-style hard negative mining, RocketQAv2 joint training
  • Multi-functionality: Dense, sparse (SPLADE-style), and multi-vector (ColBERT-style) retrieval
  • Cross-lingual Support: Chinese primary with English query compatibility
  • Platform Compatibility: NVIDIA CUDA and Huawei Ascend NPU support with automatic detection
  • Comprehensive Evaluation: Built-in metrics (Recall@k, MRR@k, NDCG@k, MAP) with additional success rate and coverage metrics
  • Production Ready: Robust error handling, atomic checkpointing, comprehensive logging, and distributed training

Recent Enhancements

  • Enhanced Error Recovery: Training continues despite batch failures with configurable thresholds
  • Atomic Checkpoint Saving: Prevents corruption with temporary directories and backup mechanisms
  • Improved Memory Management: Periodic cache clearing and gradient checkpointing for large models
  • Configuration Precedence: Clear hierarchy (CLI > Model-specific > Hardware > Training > Defaults)
  • Unified AMP Imports: Consistent torch.cuda.amp usage across all modules
  • Comprehensive Metrics: Fixed MRR/NDCG implementations with edge case handling
  • Ascend NPU Backend: Complete support module with device detection and optimization
  • Enhanced Dataset Validation: Robust format detection with helpful error messages

📋 Technical Overview

BGE-M3 (Embedding Model)

  • InfoNCE Loss with temperature τ=0.02 (verified from BGE-M3 paper)
  • Proper Loss Normalization: Fixed cross-entropy calculation in listwise loss
  • Self-Knowledge Distillation combining dense, sparse, and multi-vector signals
  • ANCE-style Hard Negative Mining with asynchronous ANN index updates
  • Multi-granularity Support up to 8192 tokens
  • Unified Training for all three functionalities
  • Gradient Checkpointing for memory-efficient training on large batches
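
For intuition, here is a minimal InfoNCE with in-batch negatives at τ=0.02 as a self-contained sketch; the project's actual implementation (which also folds in the sparse and multi-vector signals) lives in models/losses.py.

import torch
import torch.nn.functional as F

def info_nce_loss(q, p, temperature=0.02):
    """Minimal InfoNCE with in-batch negatives.

    q: (B, D) query embeddings; p: (B, D) positive passage embeddings.
    Each query's positive is the same-index row of p; every other row
    in the batch serves as a negative.
    """
    q = F.normalize(q, dim=-1)
    p = F.normalize(p, dim=-1)
    logits = q @ p.T / temperature            # (B, B) cosine similarities
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, labels)    # positives on the diagonal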

BGE-Reranker (Cross-encoder)

  • Listwise Cross-Entropy loss (fixed implementation, not KL divergence)
  • Dynamic Listwise Distillation following RocketQAv2 approach
  • Hard Negative Selection from retriever results
  • Knowledge Distillation from teacher models
  • Mixed Precision Training with bfloat16 support for RTX 5090
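
As a reference point, here is a minimal listwise cross-entropy over groups of train_group_size passages, assuming the positive sits at index 0 of each group; a sketch, not the exact code in models/losses.py.

import torch.nn.functional as F

def listwise_ce_loss(scores, temperature=1.0):
    """scores: (num_queries, group_size) cross-encoder logits, with the
    positive passage at index 0 of every group."""
    log_probs = F.log_softmax(scores / temperature, dim=-1)
    return -log_probs[:, 0].mean()  # negative log-likelihood of the positive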

Joint Training (RocketQAv2)

  • Mutual Distillation between retriever and reranker
  • Dynamic Hard Negative Mining using evolving models
  • Hybrid Data Augmentation strategies
  • Error-resilient Training Loop with failure recovery
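
One common formulation of the mutual-distillation term is a symmetric KL between the two models' score distributions over a shared candidate list; the sketch below illustrates that idea and is not necessarily the exact objective in training/joint_trainer.py.

import torch.nn.functional as F

def mutual_distill_loss(retriever_scores, reranker_scores, temperature=1.0):
    """Both inputs: (num_queries, num_candidates) scores over the same
    candidates. Each model is pulled toward the other's detached distribution."""
    r = F.log_softmax(retriever_scores / temperature, dim=-1)
    k = F.log_softmax(reranker_scores / temperature, dim=-1)
    loss_r = F.kl_div(r, k.detach().exp(), reduction="batchmean")
    loss_k = F.kl_div(k, r.detach().exp(), reduction="batchmean")
    return loss_r + loss_k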

🛠 Installation

Hardware Requirements

  • For RTX 5090: Python 3.10+, CUDA >= 12.8, Driver >= 545.x
  • For Ascend NPU: Python 3.10+, CANN >= 7.0, torch_npu
  • Memory: 24GB+ GPU memory recommended for full model training

Step 1: Install PyTorch (CUDA 12.8+ for RTX 5090)

# Install PyTorch nightly with CUDA 12.8 support
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128

Step 2: Install Dependencies

pip install -r requirements.txt

Step 3: (Optional) Platform-specific Setup

For Huawei Ascend NPU:

# Install torch_npu (follow official Huawei documentation)
pip install torch_npu

# Set environment variables
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3  # Your NPU devices

For Distributed Training:

# No additional setup needed - automatic backend selection:
# - NCCL for NVIDIA GPUs
# - HCCL for Ascend NPUs
# - Gloo for CPU

Step 4: Download Models (Optional - auto-downloads if not present)

# Models will auto-download on first use, or manually download:
huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3
huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base

🔧 Configuration

Configuration System Overview

The project uses a hierarchical configuration system with clear precedence rules and automatic validation. Configuration is managed through config.toml with programmatic overrides via command-line arguments.

Configuration Precedence Order (Highest to Lowest)

  1. Command-line arguments - Override any config file settings
  2. Model-specific sections - [m3] for BGE-M3, [reranker] for BGE-Reranker
  3. Hardware-specific sections - [hardware], [ascend] for device-specific settings
  4. Training defaults - [training] section for general defaults
  5. Dataclass field defaults - Built-in defaults in code

Example precedence in action:

# config.toml
[training]
default_batch_size = 4

[hardware]
default_batch_size = 8  # Overrides [training] for hardware optimization

[m3]
per_device_train_batch_size = 16  # Overrides both [hardware] and [training] for M3 model
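
The lookup behaves roughly like the sketch below (the real logic lives in utils/config_loader.py; the function name and signature here are illustrative):

import tomllib  # Python 3.11+; on 3.10, use the `tomli` backport

def resolve(config, key, cli_args, sections=("m3", "hardware", "training")):
    """Return the highest-precedence value for `key`:
    CLI > model section > hardware section > training defaults."""
    if cli_args.get(key) is not None:
        return cli_args[key]
    for section in sections:
        if key in config.get(section, {}):
            return config[section][key]
    raise KeyError(key)

with open("config.toml", "rb") as f:
    cfg = tomllib.load(f)

# Resolves to 16 from [m3], even though [hardware] also sets the key to 8
batch = resolve(cfg, "per_device_train_batch_size", cli_args={})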

Main Configuration File (config.toml)

# API Configuration for augmentation and external services
[api]
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here"  # Replace with your actual API key
model = "deepseek-chat"
max_retries = 3
timeout = 30
temperature = 0.7

# General training defaults
[training]
default_batch_size = 4
default_learning_rate = 1e-5
default_num_epochs = 3
default_warmup_ratio = 0.1
default_weight_decay = 0.01
gradient_accumulation_steps = 4
max_grad_norm = 1.0
logging_steps = 100
save_strategy = "epoch"
evaluation_strategy = "epoch"
load_best_model_at_end = true
metric_for_best_model = "eval_loss"
greater_is_better = false
save_total_limit = 3
fp16 = true
bf16 = false  # Enable for A100/RTX5090
gradient_checkpointing = false
dataloader_num_workers = 4
remove_unused_columns = true
label_names = ["labels"]

# Model paths with automatic download
[model_paths]
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
cache_dir = "./cache/models"

# Data processing configuration
[data]
max_query_length = 64
max_passage_length = 512
max_seq_length = 512
validation_split = 0.1
random_seed = 42
preprocessing_num_workers = 4
overwrite_cache = false
pad_to_max_length = false
return_tensors = "pt"

# Hardware-specific optimizations
[hardware]
device_type = "cuda"  # cuda, npu, cpu
mixed_precision = true
mixed_precision_dtype = "fp16"  # fp16, bf16
enable_tf32 = true  # For Ampere GPUs
cudnn_benchmark = true
cudnn_deterministic = false
per_device_train_batch_size = 8
per_device_eval_batch_size = 16

# Distributed training configuration
[distributed]
backend = "nccl"  # nccl, gloo, hccl (auto-detected if not set)
timeout_minutes = 30
find_unused_parameters = false
gradient_as_bucket_view = true
static_graph = false

# Huawei Ascend NPU specific settings
[ascend]
device_type = "npu"
backend = "hccl"
mixed_precision = true
mixed_precision_dtype = "fp16"
visible_devices = "0,1,2,3"
enable_graph_mode = false
enable_profiling = false
profiling_start_step = 10
profiling_stop_step = 20

# BGE-M3 specific configuration
[m3]
model_name_or_path = "BAAI/bge-m3"
learning_rate = 1e-5
per_device_train_batch_size = 16
per_device_eval_batch_size = 32
num_train_epochs = 5
warmup_ratio = 0.1
temperature = 0.02  # InfoNCE temperature (from paper)
use_self_distill = true
self_distill_weight = 1.0
use_hard_negatives = true
num_hard_negatives = 15
negatives_cross_batch = true
pooling_method = "cls"
normalize_embeddings = true
use_sparse = true
use_colbert = true
colbert_dim = 32

# BGE-Reranker specific configuration
[reranker]
model_name_or_path = "BAAI/bge-reranker-base"
learning_rate = 6e-5
per_device_train_batch_size = 8
per_device_eval_batch_size = 16
num_train_epochs = 3
warmup_ratio = 0.06
train_group_size = 16
loss_type = "listwise_ce"  # listwise_ce, ranknet, lambdarank
use_gradient_checkpointing = true
temperature = 1.0
use_distillation = false
teacher_model_path = "BAAI/bge-reranker-large"
distillation_weight = 0.5

# Data augmentation settings
[augmentation]
enable_query_augmentation = false
enable_passage_augmentation = false
augmentation_methods = ["paraphrase", "back_translation", "mask_and_fill"]
num_augmentations_per_sample = 2
augmentation_temperature = 0.8
api_augmentation = true
local_augmentation_model = "google/flan-t5-base"

# Hard negative mining configuration
[hard_negative_mining]
enable_mining = true
mining_strategy = "ance"  # ance, random, bm25
num_hard_negatives = 15
negative_depth = 200  # How deep to search for negatives
update_interval = 1000  # Steps between mining updates
margin = 0.1
use_cross_batch_negatives = true
dedup_negatives = true

# FAISS index settings for mining
[faiss]
index_type = "IVF"  # Flat, IVF, HNSW
index_params = {nlist = 100, nprobe = 10}
use_gpu = true
normalize_vectors = true

# Checkpoint and recovery settings
[checkpoint]
save_strategy = "epoch"
save_steps = 1000
save_total_limit = 3
save_on_each_node = false
resume_from_checkpoint = true
checkpoint_save_dir = "./checkpoints"
use_atomic_save = true  # Prevents corruption
keep_backup = true

# Logging configuration
[logging]
logging_dir = "./logs"
logging_strategy = "steps"
logging_steps = 100
logging_first_step = true
report_to = ["tensorboard", "wandb"]
wandb_project = "bge-finetuning"
# wandb_run_name is auto-generated when omitted (TOML has no null value)
disable_tqdm = false
tqdm_mininterval = 1

# Evaluation settings
[evaluation]
evaluation_strategy = "epoch"
eval_steps = 500
eval_batch_size = 32
metric_for_best_model = "eval_loss"
load_best_model_at_end = true
compute_metrics_fn = "accuracy_and_f1"
greater_is_better = false
# eval_accumulation_steps: omit to disable (TOML has no null value)

# Memory optimization
[memory]
gradient_checkpointing = true
gradient_checkpointing_kwargs = {use_reentrant = false}
optim = "adamw_torch"  # adamw_torch, adamw_hf, sgd, adafactor
# optim_args: omit to use optimizer defaults (TOML has no null value)
clear_cache_interval = 100  # Steps between cache clearing
# max_memory_mb: omit for no limit (TOML has no null value)

# Error handling and recovery
[error_handling]
continue_on_error = true
max_failed_batches = 10
error_log_file = "./logs/errors.log"
save_on_error = true
retry_failed_batches = true
num_retries = 3

Command-Line Override Examples

# Override batch size and learning rate
python scripts/train_m3.py \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5

# Override hardware settings
python scripts/train_m3.py \
  --device npu \
  --mixed_precision_dtype bf16

# Override multiple config sections
python scripts/train_m3.py \
  --config_file custom_config.toml \
  --temperature 0.05 \
  --num_hard_negatives 20 \
  --gradient_checkpointing

Environment Variable Support

# Set API key via environment variable
export BGE_API_KEY="your-api-key"

# Override config file location
export BGE_CONFIG_FILE="/path/to/custom/config.toml"

# Set distributed training environment
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export WORLD_SIZE=4

Configuration Validation

The configuration system automatically validates settings:

# In your code
from config.base_config import BaseConfig

config = BaseConfig.from_toml("config.toml")

# Automatic validation includes:
# - Type checking (int, float, str, bool, list, dict)
# - Value range validation (e.g., learning_rate > 0)
# - Compatibility checks (e.g., bf16 requires compatible hardware)
# - Missing required fields
# - Deprecated parameter warnings

Dynamic Configuration

# Programmatic configuration updates
from config.m3_config import M3Config

config = M3Config.from_toml("config.toml")

# Update based on hardware
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    config.mixed_precision_dtype = "bf16"
    config.per_device_train_batch_size *= 2

# Update based on dataset size
if dataset_size > 1000000:
    config.gradient_accumulation_steps = 8
    config.logging_steps = 500

Best Practices

  1. Start with defaults: The provided config.toml has sensible defaults
  2. Hardware-specific tuning: Use [hardware] section for GPU/NPU optimization
  3. Model-specific settings: Use [m3] or [reranker] for model-specific params
  4. Environment isolation: Use different config files for dev/staging/prod
  5. Version control: Track config.toml changes in git
  6. Secrets management: Use environment variables for API keys

🚀 Quick Start

Step 1: Prepare Your Data

Create your training data in JSONL format. Each line should be a valid JSON object.

For embedding models (BGE-M3):

{"query": "What is machine learning?", "pos": ["Machine learning is..."], "neg": ["The weather is..."]}
{"query": "How to cook pasta?", "pos": ["Boil water and add pasta..."], "neg": ["Machine learning..."]}

For reranker models:

{"query": "What is AI?", "passage": "AI is artificial intelligence...", "label": 1}
{"query": "What is AI?", "passage": "The weather today is sunny...", "label": 0}

Step 2: Fine-tune BGE-M3 (Embedding Model)

# Basic training with automatic hardware detection
python scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/bge-m3-finetuned \
  --num_train_epochs 3 \
  --per_device_train_batch_size 4 \
  --learning_rate 1e-5

# Advanced training with all features enabled
python scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/bge-m3-advanced \
  --num_train_epochs 5 \
  --per_device_train_batch_size 8 \
  --learning_rate 1e-5 \
  --use_self_distill \
  --use_hard_negatives \
  --gradient_checkpointing \
  --bf16 \
  --save_total_limit 3 \
  --logging_steps 100 \
  --evaluation_strategy "steps" \
  --eval_steps 500 \
  --metric_for_best_model "eval_loss" \
  --load_best_model_at_end

Step 3: Fine-tune BGE-Reranker (Cross-encoder)

# Basic reranker training
python scripts/train_reranker.py \
  --train_data data/train.jsonl \
  --output_dir ./output/bge-reranker-finetuned \
  --learning_rate 6e-5 \
  --train_group_size 16 \
  --loss_type listwise_ce

# Advanced reranker training with knowledge distillation
python scripts/train_reranker.py \
  --train_data data/train.jsonl \
  --output_dir ./output/bge-reranker-advanced \
  --learning_rate 6e-5 \
  --train_group_size 16 \
  --loss_type listwise_ce \
  --use_distillation \
  --teacher_model_path BAAI/bge-reranker-large \
  --bf16 \
  --gradient_checkpointing \
  --save_strategy "epoch" \
  --evaluation_strategy "epoch"

Step 4: Joint Training (RocketQAv2-style)

# Joint training with mutual distillation
python scripts/train_joint.py \
  --train_data data/train.jsonl \
  --output_dir ./output/joint-training \
  --use_dynamic_distillation \
  --joint_loss_weight 0.3 \
  --bf16 \
  --gradient_checkpointing \
  --max_grad_norm 1.0 \
  --dataloader_num_workers 4

Step 5: Multi-GPU / Distributed Training

# NVIDIA GPUs (automatic NCCL backend)
torchrun --nproc_per_node=4 scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/distributed \
  --per_device_train_batch_size 16 \
  --gradient_accumulation_steps 2 \
  --bf16 \
  --gradient_checkpointing

# Huawei Ascend NPUs (automatic HCCL backend)
# First, configure config.toml [ascend] section, then:
python scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/ascend \
  --device npu

📊 Data Format

Canonical Format (Required Structure)

All data formats must contain these required fields:

  • query (string): The search query or question
  • pos (list of strings): Positive (relevant) passages
  • neg (list of strings): Negative (irrelevant) passages

Optional fields for advanced training:

  • pos_scores (list of floats): Relevance scores for positives (0.0-1.0)
  • neg_scores (list of floats): Relevance scores for negatives (0.0-1.0)
  • prompt (string): Custom prompt for encoding
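
A minimal validator for this structure, as a sketch (the project's loader performs a more thorough version of these checks):

import json

REQUIRED = {"query": str, "pos": list, "neg": list}

def validate_line(line, lineno):
    rec = json.loads(line)
    for field, typ in REQUIRED.items():
        if not isinstance(rec.get(field), typ):
            raise ValueError(f"line {lineno}: '{field}' missing or not {typ.__name__}")
    if not rec["pos"]:
        raise ValueError(f"line {lineno}: 'pos' needs at least one passage")
    for key in ("pos_scores", "neg_scores"):  # optional graded-relevance fields
        if key in rec and not all(0.0 <= s <= 1.0 for s in rec[key]):
            raise ValueError(f"line {lineno}: {key} values must be in [0.0, 1.0]")
    return rec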

Supported Input Formats

Our dataset loader automatically detects and validates the following formats:

{"query": "What is machine learning?", "pos": ["Machine learning is..."], "neg": ["Deep learning is..."]}
{"query": "深度学习的应用", "pos": ["深度学习在计算机视觉..."], "neg": ["机器学习是..."]}

2. JSON (Array Format)

[
  {
    "query": "What is BGE?",
    "pos": ["BGE is a family of embedding models..."],
    "neg": ["BERT is a transformer model..."],
    "pos_scores": [1.0],
    "neg_scores": [0.2]
  }
]

3. CSV/TSV

query,pos,neg,pos_scores,neg_scores
"What is machine learning?","['ML is...', 'Machine learning refers to...']","['DL is...']","[1.0, 0.9]","[0.1]"

4. DuReader Format

{
  "data": [{
    "question": "什么是机器学习?",
    "answers": ["机器学习是..."],
    "documents": [{
      "content": "深度学习是...",
      "is_relevant": false
    }]
  }]
}

5. MS MARCO Format (TSV)

query_id	query	doc_id	document	relevance
1001	What is ML?	2001	Machine learning is...	1
1001	What is ML?	2002	Deep learning is...	0

6. XML Format (TREC-style)

<topics>
  <topic>
    <query>Information retrieval basics</query>
    <positive>IR is the process of obtaining...</positive>
    <negative>Database management is...</negative>
  </topic>
</topics>

Data Validation & Error Handling

The dataset loader performs comprehensive validation:

  • Format Detection: Automatic format inference with helpful error messages
  • Field Validation: Ensures all required fields are present and properly typed
  • Content Validation: Checks for empty queries/passages, invalid scores
  • Encoding Validation: Handles various text encodings (UTF-8, GB2312, etc.)
  • Size Validation: Warns about extremely long sequences

Example validation output:

Dataset validation completed:
✓ Format: JSONL
✓ Total samples: 10,000
✓ Valid samples: 9,998
✗ Invalid samples: 2
  - Line 1523: Empty query field
  - Line 7891: Missing 'neg' field
✓ Encoding: UTF-8
⚠ Warnings:
  - 15 samples have passages > 8192 tokens (will be truncated)

Converting Data Formats

Use the built-in preprocessor to convert between formats:

from data.preprocessing import DataPreprocessor

preprocessor = DataPreprocessor()

# Convert any format to canonical JSONL
preprocessor.convert_to_canonical(
    input_file="data/msmarco.tsv",
    output_file="data/train.jsonl",
    input_format="msmarco"  # or "auto" for automatic detection
)

# Batch conversion with validation
preprocessor.batch_convert(
    input_files=["data/*.csv", "data/*.json"],
    output_dir="data/processed/",
    validate=True,
    max_workers=4
)

Working with Book/Document Collections

For training on books or large documents:

  1. Split documents into passages:
from data.preprocessing import DocumentSplitter

splitter = DocumentSplitter(
    chunk_size=512,
    overlap=50,
    split_by="sentence"  # or "paragraph", "sliding_window"
)

passages = splitter.split_document("path/to/book.txt")
  2. Generate training pairs:
from data.preprocessing import TrainingPairGenerator

generator = TrainingPairGenerator()

# Generate query-passage pairs with hard negatives
training_data = generator.generate_from_passages(
    passages=passages,
    queries_per_passage=2,
    hard_negative_ratio=0.7,
    use_bm25=True
)

# Save in canonical format
generator.save_jsonl(training_data, "data/book_train.jsonl")
  3. Use data optimization:
# Optimize data for BGE-M3 training
python -m data.optimization bge-m3 data/book_train.jsonl \
  --output_dir ./optimized_data \
  --hard_negative_ratio 0.8 \
  --augment_queries

Memory-Efficient Data Loading

For large datasets, use streaming mode:

# In your training script
dataset = ContrastiveDataset(
    data_file="data/large_dataset.jsonl",
    tokenizer=tokenizer,
    streaming=True,  # Enable streaming
    buffer_size=10000  # Buffer size for shuffling
)

📖 Training with Book .txt Files

  • You cannot train directly on a raw .txt book file.
  • Convert your book to a supported format:
    • Split the book into passages (e.g., by chapter or paragraph).
    • Create queries and annotate which passages are positive (relevant) and negative (irrelevant) for each query.
    • Example:
      {"query": "What happens in Chapter 1?", "pos": ["Chapter 1: ..."], "neg": ["Chapter 2: ..."]}
      
  • Use the provided DataPreprocessor to convert CSV/TSV/JSON/XML to JSONL.

Note: All formats must be convertible to the canonical JSONL structure above. Use the provided DataPreprocessor to convert between formats.


⚙️ Configuration Reference

Parameter Precedence

  • [hardware] overrides [training] for overlapping parameters (e.g., batch size).
  • [m3] and [reranker] override [training]/[hardware] for their respective models.
  • For Ascend NPU, use the [ascend] section and set device_type = "npu".

Hardware Compatibility

  • NVIDIA RTX 5090: Use torch >=2.7.1 and CUDA 12.8+ (see above).
  • Huawei Ascend NPU: Install torch_npu, set [hardware] or [ascend] config, and ensure your environment matches Ascend requirements.

API Configuration (config.toml)

[api]
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here"
model = "deepseek-chat"

[training]
default_batch_size = 4
default_learning_rate = 1e-5

[model_paths]
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"

Configuration Parameter Reference

[api]

Parameter     Type    Default                     Description
base_url      str     "https://api.deepseek..."   API endpoint for augmentation
api_key       str     "your-api-key-here"         API key for external services
model         str     "deepseek-chat"             Model name for API
max_retries   int     3                           Max API retries
timeout       int     30                          API timeout (seconds)
temperature   float   0.7                         API sampling temperature

[training]

Parameter                     Type    Default   Description
default_batch_size            int     4         Default batch size for training
default_learning_rate         float   1e-5      Default learning rate
default_num_epochs            int     3         Default number of epochs
default_warmup_ratio          float   0.1       Warmup steps ratio
default_weight_decay          float   0.01      Weight decay for optimizer
gradient_accumulation_steps   int     4         Steps to accumulate gradients

[model_paths]

Parameter      Type   Default                        Description
bge_m3         str    "./models/bge-m3"              Path to BGE-M3 model
bge_reranker   str    "./models/bge-reranker-base"   Path to BGE-Reranker model
cache_dir      str    "./cache/models"               Model cache directory

[data]

Parameter            Type    Default   Description
max_query_length     int     64        Max query length
max_passage_length   int     512       Max passage length
max_seq_length       int     512       Max sequence length
validation_split     float   0.1       Validation split ratio
random_seed          int     42        Random seed for reproducibility

[augmentation]

Parameter                      Type    Default                               Description
enable_query_augmentation      bool    false                                 Enable query augmentation
enable_passage_augmentation    bool    false                                 Enable passage augmentation
augmentation_methods           list    ["paraphrase", "back_translation"]    Augmentation methods to use
num_augmentations_per_sample   int     2                                     Number of augmentations per sample
augmentation_temperature       float   0.8                                   Temperature for augmentation
api_first                      bool    true                                  Prefer LLM API for all augmentation if enabled

[hard_negative_mining]

Parameter                Type    Default   Description
enable_mining            bool    true      Enable hard negative mining
num_hard_negatives       int     15        Number of hard negatives to mine
negative_range_min       int     10        Min negative index for mining
negative_range_max       int     100       Max negative index for mining
margin                   float   0.1       Margin for ANCE mining
index_refresh_interval   int     1000      Steps between index refreshes
index_type               str     "IVF"     FAISS index type
nlist                    int     100       Number of clusters for IVF

[m3] and [reranker]

  • All model-specific parameters are now documented in the config classes and code docstrings. See config/m3_config.py and config/reranker_config.py for full details.

Distributed Training

Overview

The project supports distributed training across multiple GPUs/NPUs with automatic backend detection, error recovery, and optimized communication strategies. The distributed training system is fully configuration-driven and handles both data-parallel and model-parallel strategies.

Automatic Backend Detection

The system automatically selects the optimal backend based on your hardware:

  • NVIDIA GPUs: NCCL (fastest for GPU-to-GPU communication)
  • Huawei Ascend NPUs: HCCL (optimized for Ascend architecture)
  • CPU or Mixed: Gloo (fallback for heterogeneous setups)
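
The selection is roughly equivalent to this sketch (the actual implementation is in utils/distributed.py):

import torch
import torch.distributed as dist

def pick_backend():
    """NCCL for CUDA GPUs, HCCL for Ascend NPUs, Gloo as the fallback."""
    if torch.cuda.is_available():
        return "nccl"
    try:
        import torch_npu  # noqa: F401  (present only on Ascend systems)
        if torch.npu.is_available():
            return "hccl"
    except ImportError:
        pass
    return "gloo"

# torchrun supplies RANK / WORLD_SIZE / MASTER_ADDR via the environment
dist.init_process_group(backend=pick_backend())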

Basic Usage

Single-Node Multi-GPU (NVIDIA)

# Automatic detection and setup
torchrun --nproc_per_node=4 scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/distributed

# With specific GPU selection
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/distributed \
  --gradient_accumulation_steps 2

Multi-Node Training (NVIDIA)

# On node 0 (master)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
  --master_addr=192.168.1.1 --master_port=29500 \
  scripts/train_m3.py --train_data data/train.jsonl

# On node 1
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
  --master_addr=192.168.1.1 --master_port=29500 \
  scripts/train_m3.py --train_data data/train.jsonl

Huawei Ascend NPU Training

# Single-node multi-NPU
python scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/ascend \
  --device npu

# With specific NPU selection
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
python scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/ascend

Configuration Options

# config.toml
[distributed]
backend = "auto"  # auto, nccl, gloo, hccl
timeout_minutes = 30
find_unused_parameters = false
gradient_as_bucket_view = true
broadcast_buffers = true
bucket_cap_mb = 25
use_distributed_optimizer = false

# Performance optimizations
[distributed.optimization]
overlap_comm = true
gradient_compression = false
compression_type = "none"  # none, powerSGD, fp16
all_reduce_fusion = true
fusion_threshold_mb = 64

# Fault tolerance
[distributed.fault_tolerance]
enable_checkpointing = true
checkpoint_interval = 1000
restart_on_failure = true
max_restart_attempts = 3
heartbeat_interval = 60

Advanced Features

1. Gradient Accumulation with Distributed Training
# Effective batch size = per_device_batch * gradient_accumulation * num_gpus
# Example: 8 * 4 * 4 = 128 effective batch size
python scripts/train_m3.py \
  --per_device_train_batch_size 8 \
  --gradient_accumulation_steps 4 \
  --distributed

2. Mixed Precision with Distributed

# Automatic mixed precision for faster training
torchrun --nproc_per_node=4 scripts/train_m3.py \
  --bf16 \
  --tf32  # Additional optimization for Ampere GPUs

3. DeepSpeed Integration

# With DeepSpeed for extreme scale
deepspeed scripts/train_m3.py \
  --deepspeed deepspeed_config.json \
  --train_data data/train.jsonl

Example DeepSpeed config:

{
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto",
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000
  },
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  }
}

4. Fault Tolerance and Recovery

The distributed training system includes robust error handling:

# Training automatically recovers from:
# - Temporary network failures
# - GPU memory errors (with retry)
# - Node failures (with checkpoint recovery)
# - Communication timeouts
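
In the spirit of the [error_handling] settings (num_retries, retry_failed_batches), batch-level retry looks roughly like this sketch; the real loop in training/trainer.py also logs failures and honors max_failed_batches.

import torch

def run_batch_with_retry(step_fn, batch, num_retries=3):
    """Retry one training step after a CUDA OOM, clearing cache in between."""
    for attempt in range(num_retries + 1):
        try:
            return step_fn(batch)
        except torch.cuda.OutOfMemoryError:
            if attempt == num_retries:
                raise  # give up after the configured number of retries
            torch.cuda.empty_cache()  # release cached blocks before retrying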

5. Performance Monitoring

# Enable distributed profiling
python scripts/train_m3.py \
  --enable_profiling \
  --profiling_start_step 100 \
  --profiling_stop_step 200

# Monitor with tensorboard
tensorboard --logdir ./logs --bind_all

Best Practices

  1. Data Loading

    # Use distributed sampler (automatic)
    # Set num_workers based on CPUs
    --dataloader_num_workers 4
    
  2. Batch Size Scaling

    # Scale learning rate with total batch size:
    # LR = base_lr * (total_batch_size / base_batch_size)
    # e.g., 4 GPUs x batch 16 = 64 total: 1e-5 * (64 / 16) = 4e-5
    --learning_rate 4e-5
    
  3. Communication Optimization

    # Reduce communication overhead
    --gradient_accumulation_steps 4
    --broadcast_buffers false  # If not needed
    
  4. Memory Optimization

    # For large models
    --gradient_checkpointing
    --zero_stage 2  # With DeepSpeed
    

Troubleshooting

NCCL Errors

# Debug NCCL issues
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

# Common fixes
export NCCL_IB_DISABLE=1  # Disable InfiniBand
export NCCL_P2P_DISABLE=1  # Disable P2P

Hanging Training

# Set timeout for debugging
export NCCL_TIMEOUT=300  # 5 minutes
export NCCL_ASYNC_ERROR_HANDLING=1

Memory Issues

# Clear cache periodically
--clear_cache_interval 100

# Use gradient checkpointing
--gradient_checkpointing

Performance Benchmarks

Typical speedup with distributed training:

  • 2 GPUs: 1.8-1.9x speedup
  • 4 GPUs: 3.5-3.8x speedup
  • 8 GPUs: 6.5-7.5x speedup

Factors affecting performance:

  • Model size (larger models scale better)
  • Batch size (larger batches reduce communication overhead)
  • Network bandwidth (especially for multi-node)
  • Data loading speed

📊 Evaluation Metrics

Overview

The evaluation system provides comprehensive metrics for both retrieval and reranking tasks, with support for graded relevance, statistical significance testing, and performance benchmarking.

Core Metrics

Retrieval Metrics

  1. Recall@K - Fraction of relevant documents retrieved in top K

    recall = |retrieved_relevant| / |total_relevant|
    
  2. Precision@K - Fraction of retrieved documents that are relevant

    precision = |retrieved_relevant| / K
    
  3. Mean Average Precision (MAP) - Average precision at each relevant position

    MAP = mean([AP(q) for q in queries])
    where AP = sum([P@k * rel(k) for k in positions]) / |relevant|
    
  4. Mean Reciprocal Rank (MRR@K) - Average reciprocal of first relevant result rank

    MRR = mean([1/rank_of_first_relevant for q in queries])
    
  5. Normalized Discounted Cumulative Gain (NDCG@K) - Position-weighted relevance score

    DCG@K = sum([rel(i) / log2(i + 1) for i in 1..K])
    NDCG@K = DCG@K / IDCG@K
    
  6. Success Rate@K - Fraction of queries with at least one relevant result in top K

    success_rate = |queries_with_relevant_in_topK| / |total_queries|
    
  7. Coverage@K - Fraction of all relevant documents found across all queries

    coverage = |unique_relevant_retrieved| / |total_unique_relevant|
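
Following the formulas above, a minimal binary-relevance sketch of three of these metrics for a single query (the project's evaluation/metrics.py additionally handles graded labels and edge cases):

import math

def recall_at_k(ranked_ids, relevant_ids, k):
    hits = sum(1 for d in ranked_ids[:k] if d in relevant_ids)
    return hits / len(relevant_ids) if relevant_ids else 0.0

def mrr_at_k(ranked_ids, relevant_ids, k):
    for rank, d in enumerate(ranked_ids[:k], start=1):
        if d in relevant_ids:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(ranked_ids, relevant_ids, k):
    dcg = sum(1.0 / math.log2(rank + 1)
              for rank, d in enumerate(ranked_ids[:k], start=1)
              if d in relevant_ids)
    ideal = sum(1.0 / math.log2(rank + 1)
                for rank in range(1, min(k, len(relevant_ids)) + 1))
    return dcg / ideal if ideal > 0 else 0.0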
    

Reranking Metrics

  1. Accuracy@1 - Top-1 ranking accuracy
  2. MRR - Mean reciprocal rank for reranked results
  3. NDCG@K - Ranking quality with graded relevance
  4. ERR (Expected Reciprocal Rank) - Cascade model metric

Usage Examples

Basic Evaluation

# Evaluate retriever
python scripts/evaluate.py \
  --eval_type retriever \
  --model_path ./output/bge-m3-finetuned \
  --eval_data data/test.jsonl \
  --k_values 1 5 10 20 50 \
  --output_file results/retriever_metrics.json

# Evaluate reranker
python scripts/evaluate.py \
  --eval_type reranker \
  --model_path ./output/bge-reranker-finetuned \
  --eval_data data/test.jsonl \
  --output_file results/reranker_metrics.json

# Evaluate full pipeline
python scripts/evaluate.py \
  --eval_type pipeline \
  --retriever_path ./output/bge-m3-finetuned \
  --reranker_path ./output/bge-reranker-finetuned \
  --eval_data data/test.jsonl \
  --retrieval_top_k 50 \
  --rerank_top_k 10

Advanced Evaluation with Graded Relevance

# Data format with relevance scores
{
  "query": "machine learning basics",
  "documents": [
    {"text": "Introduction to ML", "relevance": 2},  # Highly relevant
    {"text": "Deep learning guide", "relevance": 1},  # Somewhat relevant
    {"text": "Database systems", "relevance": 0}      # Not relevant
  ]
}
# Evaluate with graded relevance
python scripts/evaluate.py \
  --eval_type retriever \
  --model_path ./output/bge-m3-finetuned \
  --eval_data data/test_graded.jsonl \
  --use_graded_relevance \
  --relevance_threshold 1 \
  --ndcg_k_values 5 10 20

Statistical Significance Testing

# Compare two models with significance testing
python scripts/compare_retriever.py \
  --baseline_model_path BAAI/bge-m3 \
  --finetuned_model_path ./output/bge-m3-finetuned \
  --data_path data/test.jsonl \
  --num_bootstrap_samples 1000 \
  --confidence_level 0.95

# Output includes:
# - Metric comparisons with confidence intervals
# - Statistical significance (p-values)
# - Effect sizes (Cohen's d)

Metric Computation Details

Handling Edge Cases

  1. Empty Results

    • Recall: 0 for empty results (also 0 when no relevant documents exist)
    • Precision: 0 for empty results
    • MRR: 0 if no relevant document found
    • NDCG: 0 for empty results
  2. Tied Scores

    • Uses stable sorting to maintain consistency
    • Averages metrics across tied positions
  3. Missing Relevance Labels

    • Binary relevance: treats unlabeled as non-relevant
    • Graded relevance: requires explicit labels

Efficiency Optimizations

# Batch evaluation for large datasets
evaluator = Evaluator(
    model=model,
    batch_size=128,
    use_fp16=True,
    device="cuda"
)

# Streaming evaluation for memory efficiency
results = evaluator.evaluate_streaming(
    data_path="large_dataset.jsonl",
    chunk_size=10000
)

Performance Benchmarking

# Benchmark inference speed
python scripts/benchmark.py \
  --model_type retriever \
  --model_path ./output/bge-m3-finetuned \
  --data_path data/test.jsonl \
  --batch_sizes 1 8 16 32 64 \
  --num_warmup_steps 10 \
  --num_benchmark_steps 100

# Output includes:
# - Throughput (queries/second)
# - Latency percentiles (p50, p90, p95, p99)
# - Memory usage
# - Hardware utilization

Custom Metrics

from evaluation.metrics import BaseMetric

class CustomMetric(BaseMetric):
    def compute(self, predictions, references, k=10):
        score = 0.0  # your custom metric logic here; return a float
        return score

# Register custom metric
from evaluation.evaluator import Evaluator
evaluator = Evaluator(model=model)
evaluator.add_metric("custom_metric", CustomMetric())

Visualization and Reporting

# Generate evaluation report
from evaluation.reporting import EvaluationReport

report = EvaluationReport(results)
report.generate_html("evaluation_report.html")
report.generate_latex("evaluation_report.tex")
report.plot_metrics("plots/")

# Example plots:
# - Recall@K curves
# - Precision-Recall curves
# - NDCG@K comparison
# - Query-level performance distribution

Best Practices

  1. Use Appropriate K Values

    • Small K (1, 5, 10) for user-facing metrics
    • Large K (50, 100) for reranking scenarios
  2. Consider Query Types

    • Navigational queries: Focus on MRR, Success@1
    • Informational queries: Focus on NDCG, Recall@K
  3. Statistical Rigor

    • Always use test sets with 1000+ queries
    • Report confidence intervals
    • Use paired tests for model comparison
  4. Graded Relevance

    • Use NDCG for nuanced evaluation
    • Define clear relevance guidelines
    • Consider query-specific relevance scales

📈 Model Evaluation & Benchmarking

1. Detailed Accuracy Evaluation

Use scripts/evaluate.py for comprehensive accuracy evaluation of retriever, reranker, or the full pipeline. Supports all standard metrics (MRR, Recall@k, MAP, NDCG, etc.) and baseline comparison.

Example:

python scripts/evaluate.py --retriever_model path/to/finetuned --eval_data path/to/data.jsonl --eval_retriever --k_values 1 5 10
python scripts/evaluate.py --reranker_model path/to/finetuned --eval_data path/to/data.jsonl --eval_reranker
python scripts/evaluate.py --retriever_model path/to/finetuned --reranker_model path/to/finetuned --eval_data path/to/data.jsonl --eval_pipeline --compare_baseline --base_retriever path/to/baseline

2. Quick Performance Benchmarking

Use scripts/benchmark.py for fast, single-model performance profiling (throughput, latency, memory). This script does not report accuracy metrics.

Example:

python scripts/benchmark.py --model_type retriever --model_path path/to/model --data_path path/to/data.jsonl --batch_size 16 --device cuda
python scripts/benchmark.py --model_type reranker --model_path path/to/model --data_path path/to/data.jsonl --batch_size 16 --device cuda

3. Fine-tuned vs. Baseline Comparison (Accuracy & Performance)

Use scripts/compare_retriever.py and scripts/compare_reranker.py to compare a fine-tuned model against a baseline, reporting both accuracy and performance metrics side by side, including the delta (improvement) for each metric.

Example:

python scripts/compare_retriever.py --finetuned_model_path path/to/finetuned --baseline_model_path path/to/baseline --data_path path/to/data.jsonl --batch_size 16 --device cuda
python scripts/compare_reranker.py --finetuned_model_path path/to/finetuned --baseline_model_path path/to/baseline --data_path path/to/data.jsonl --batch_size 16 --device cuda
  • Results are saved to a file and include a table of metrics for both models and the improvement (delta).

🖥️ Ascend NPU Section

  • Install: pip install torch_npu (see Ascend docs)
  • Edit config.toml:
    [ascend]
    device_type = "npu"
    backend = "hccl"
    mixed_precision = true
    visible_devices = "0,1,2,3"
    
  • Launch training:
    python3.10 scripts/train_m3.py --train_data data/train.jsonl --output_dir ./output/npu
    
  • Distributed: Use Ascend's distributed launcher or set backend = "hccl" in [ascend].
  • All device/backend selection is config-driven.

🚀 Performance Optimization

Multi-GPU Training

torchrun --nproc_per_node=2 scripts/train_m3.py \
  --train_data data/train.jsonl \
  --output_dir ./output/distributed \
  --fp16 \
  --gradient_checkpointing

Memory Optimization

  • Gradient Checkpointing: Reduces memory usage
  • Mixed Precision: FP16 training for RTX 5090
  • Gradient Accumulation: Effective larger batch sizes
  • DeepSpeed Integration: For large-scale training

Huawei Ascend Compatibility

# For Ascend deployment
import torch_npu
device = torch.device("npu")
model.to(device)

📝 Recent Changes & Improvements

Simplified Data Pipeline (Latest)

  • Removed unnecessary PKL conversion pipeline: The data/optimization.py module has been removed
    • Performance benefits came from pre-tokenization, not PKL format
    • PKL files caused memory issues and added complexity without real benefits
  • Added in-memory caching option: For small datasets (< 100k samples), enable cache_in_memory=True
  • Kept JSONL as primary format: Industry standard, human-readable, and efficient
  • All core functionality intact: Training, evaluation, and model features unchanged

🚀 Performance Tips

For better training performance:

# Enable in-memory caching for small datasets
dataset = BGEM3Dataset(
    data_path="train.jsonl",
    tokenizer=tokenizer,
    cache_in_memory=True,      # Pre-tokenize and cache in memory
    max_cache_size=100000      # Maximum samples to cache
)

For most use cases, the in-memory caching option above replaces the removed data.optimization pipeline.


🏗 Architecture

bge_finetune/
├── config.toml                  # Main configuration file
├── requirements.txt             # Python dependencies
├── README.md                    # Project documentation
├── __init__.py                  # Package initialization
│
├── config/                      # Configuration management
│   ├── __init__.py
│   ├── base_config.py           # Base configuration class with validation
│   ├── m3_config.py             # BGE-M3 specific configuration
│   └── reranker_config.py       # BGE-Reranker specific configuration
│
├── data/                        # Data processing and management
│   ├── __init__.py
│   ├── dataset.py               # Dataset classes (Contrastive, Reranker, Joint)
│   ├── preprocessing.py         # Format conversion and validation
│   ├── augmentation.py          # Query/passage augmentation strategies
│   ├── hard_negative_mining.py  # ANCE-style hard negative mining
│   └── optimization.py          # Data optimization utilities
│
├── models/                      # Model implementations
│   ├── __init__.py
│   ├── bge_m3.py                # BGE-M3 wrapper with multi-functionality
│   ├── bge_reranker.py          # BGE-Reranker cross-encoder wrapper
│   └── losses.py                # Loss functions (InfoNCE, ListwiseCE, etc.)
│
├── training/                    # Training logic
│   ├── __init__.py
│   ├── trainer.py               # Base trainer with error recovery
│   ├── m3_trainer.py            # BGE-M3 specialized trainer
│   ├── reranker_trainer.py      # Reranker specialized trainer
│   └── joint_trainer.py         # RocketQAv2-style joint training
│
├── evaluation/                  # Evaluation and metrics
│   ├── __init__.py
│   ├── evaluator.py             # Evaluation pipeline
│   ├── metrics.py               # Metric implementations (MRR, NDCG, etc.)
│   └── reporting.py             # Result visualization and reporting
│
├── utils/                       # Utilities and helpers
│   ├── __init__.py
│   ├── logging.py               # Distributed-aware logging with tqdm
│   ├── checkpoint.py            # Atomic checkpoint save/load
│   ├── distributed.py           # Distributed training utilities
│   ├── config_loader.py         # TOML config loading with precedence
│   ├── memory_utils.py          # Memory optimization utilities
│   └── ascend_support.py        # Huawei Ascend NPU support
│
├── scripts/                     # Entry point scripts
│   ├── __init__.py
│   ├── train_m3.py              # Train BGE-M3 embeddings
│   ├── train_reranker.py        # Train BGE-Reranker
│   ├── train_joint.py           # Joint training (RocketQAv2)
│   ├── evaluate.py              # Comprehensive evaluation
│   ├── benchmark.py             # Performance benchmarking
│   ├── compare_retriever.py     # Retriever comparison
│   ├── compare_reranker.py      # Reranker comparison
│   ├── test_installation.py     # Environment verification
│   └── setup.py                 # Initial setup script
│
├── cache/                       # Cache directory (auto-created)
│   ├── models/                  # Downloaded model cache
│   └── data/                    # Preprocessed data cache
│
├── logs/                        # Logging directory (auto-created)
│   ├── training/                # Training logs
│   ├── evaluation/              # Evaluation results
│   └── errors.log               # Error logs
│
├── output/                      # Training outputs (auto-created)
│   ├── bge-m3-finetuned/        # Fine-tuned models
│   ├── bge-reranker-finetuned/  # Fine-tuned rerankers
│   └── checkpoints/             # Training checkpoints
│
└── docs/                        # Additional documentation
    ├── API.md                   # API reference
    ├── CONTRIBUTING.md          # Contribution guidelines
    └── USAGE_GUIDE.md           # Detailed usage guide

Module Descriptions

Core Modules

  • config/: Hierarchical configuration system with TOML support and validation

    • Automatic type checking and value validation
    • Clear precedence rules (CLI > Model > Hardware > Training > Defaults)
    • Dynamic configuration based on hardware detection
  • data/: Comprehensive data handling

    • Format detection and conversion (JSONL, JSON, CSV, TSV, XML, DuReader, MS MARCO)
    • Streaming support for large datasets
    • Built-in augmentation with LLM and local model support
    • ANCE-style hard negative mining with FAISS integration
  • models/: Model wrappers and loss functions

    • BGE-M3: Dense, sparse (SPLADE), and multi-vector (ColBERT) support
    • BGE-Reranker: Cross-encoder with knowledge distillation
    • Loss functions: InfoNCE (τ=0.02), ListwiseCE, RankNet, LambdaRank
  • training/: Robust training infrastructure

    • Error recovery with configurable failure thresholds
    • Automatic mixed precision (fp16/bf16) based on hardware
    • Gradient checkpointing and memory optimization
    • Distributed training with automatic backend selection
  • evaluation/: Comprehensive evaluation suite

    • Metrics: Recall@K, MRR@K, NDCG@K, MAP, Success Rate, Coverage
    • Graded relevance support
    • Statistical significance testing
    • Performance benchmarking
  • utils/: Production-ready utilities

    • TqdmLoggingHandler for progress tracking
    • Atomic checkpoint saving with backup
    • Automatic distributed backend detection (NCCL/HCCL/Gloo)
    • Full Ascend NPU support module

Entry Points

  • Training Scripts:

    • train_m3.py: BGE-M3 training with all features
    • train_reranker.py: Reranker training with distillation
    • train_joint.py: RocketQAv2-style joint training
  • Evaluation Scripts:

    • evaluate.py: Full evaluation with all metrics
    • benchmark.py: Performance profiling
    • compare_*.py: Side-by-side model comparison
  • Utility Scripts:

    • test_installation.py: Verify environment setup
    • setup.py: Download models and prepare environment

Key Features by Module

  1. Automatic Hardware Detection

    • CUDA capability checking for optimal dtype (fp16/bf16)
    • Ascend NPU detection and configuration
    • CPU fallback for development
  2. Robust Error Handling

    • Training continues despite batch failures
    • Automatic checkpoint recovery
    • Comprehensive error logging
  3. Memory Efficiency

    • Gradient checkpointing
    • Periodic cache clearing
    • Streaming data loading
    • Mixed precision training
  4. Production Features

    • Atomic checkpoint saves
    • Distributed training support
    • Comprehensive logging
    • Configuration validation

🔧 Troubleshooting & FAQ

Q: My GPU is not detected or torch reports CUDA version mismatch.

  • Ensure you installed torch >=2.7.1 with CUDA 12.8+ support (see installation section).
  • Run nvidia-smi to check your driver and CUDA version.

Q: How do I enable Ascend NPU support?

  • Install torch_npu and set device_type = "npu" in [hardware] or [ascend] in config.toml.
  • See Ascend documentation for details.

Q: How do config parameters interact?

  • [hardware] > [training] for overlapping parameters.
  • [m3]/[reranker] > [training]/[hardware] for their models.
  • See config class docstrings and the config.toml for details.

Q: How do I run tests?

  • Install pytest and run pytest tests/ (if tests are present).
  • Or run python scripts/test_installation.py for a full environment check.

📚 References

  1. BGE M3-Embedding: arXiv:2402.03216
  2. RocketQAv2: ACL Anthology
  3. ANCE: arXiv:2007.00808
  4. ColBERT: SIGIR 2020
  5. SPLADE: arXiv:2109.10086

🔮 Roadmap

  • Knowledge Graph integration
  • Multi-modal retrieval support
  • Additional language support
  • Distributed training optimization
  • Model quantization for deployment
  • Benchmarking on standard datasets