BGE Fine-tuning for Chinese Book Retrieval
A comprehensive, production-ready implementation for fine-tuning the BAAI BGE-M3 embedding model and the BGE-Reranker cross-encoder for optimal retrieval performance on Chinese book content, with English query support.
⚡️ Requirements & Compatibility
- Python >= 3.10
- CUDA >= 12.8 (for NVIDIA RTX 5090 and similar)
- PyTorch >= 2.7.1 with CUDA 12.8 builds (install from the nightly/cu128 index if needed; see Installation)
- See requirements.txt for all dependencies
- Ascend NPU support: install torch_npu and set the [hardware]/[ascend] config as needed
Hardware Support
- NVIDIA GPUs: RTX 5090, A100, H100 (CUDA 12.8+)
- Huawei Ascend NPUs: Ascend 910, 910B (with torch_npu)
- CPU: Fallback support for development/testing
🚀 Features
Core Capabilities
- State-of-the-art Techniques: InfoNCE loss with proper temperature scaling, self-knowledge distillation, ANCE-style hard negative mining, RocketQAv2 joint training
- Multi-functionality: Dense, sparse (SPLADE-style), and multi-vector (ColBERT-style) retrieval
- Cross-lingual Support: Chinese primary with English query compatibility
- Platform Compatibility: NVIDIA CUDA and Huawei Ascend NPU support with automatic detection
- Comprehensive Evaluation: Built-in metrics (Recall@k, MRR@k, NDCG@k, MAP) with additional success rate and coverage metrics
- Production Ready: Robust error handling, atomic checkpointing, comprehensive logging, and distributed training
Recent Enhancements
- ✅ Enhanced Error Recovery: Training continues despite batch failures with configurable thresholds
- ✅ Atomic Checkpoint Saving: Prevents corruption with temporary directories and backup mechanisms (see the sketch after this list)
- ✅ Improved Memory Management: Periodic cache clearing and gradient checkpointing for large models
- ✅ Configuration Precedence: Clear hierarchy (CLI > Model-specific > Hardware > Training > Defaults)
- ✅ Unified AMP Imports: Consistent torch.cuda.amp usage across all modules
- ✅ Comprehensive Metrics: Fixed MRR/NDCG implementations with edge case handling
- ✅ Ascend NPU Backend: Complete support module with device detection and optimization
- ✅ Enhanced Dataset Validation: Robust format detection with helpful error messages
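For illustration, here is a minimal sketch of the atomic-save pattern referenced above: write the checkpoint into a temporary directory, then atomically rename it into place while keeping a backup. This is a hedged sketch assuming a POSIX-style filesystem; the project's utils/checkpoint.py may differ.

import os
import shutil
import tempfile

def atomic_save(save_fn, target_dir):
    """save_fn writes checkpoint files into the directory it is given."""
    parent = os.path.dirname(os.path.abspath(target_dir)) or "."
    tmp = tempfile.mkdtemp(dir=parent, prefix=".ckpt_tmp_")
    try:
        save_fn(tmp)                            # write the full checkpoint first
        if os.path.exists(target_dir):
            backup = target_dir + ".bak"
            shutil.rmtree(backup, ignore_errors=True)
            os.replace(target_dir, backup)      # keep the previous checkpoint as backup
        os.replace(tmp, target_dir)             # atomic rename on the same filesystem
    except Exception:
        shutil.rmtree(tmp, ignore_errors=True)  # never leave a partial checkpoint behind
        raise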
📋 Technical Overview
BGE-M3 (Embedding Model)
- InfoNCE Loss with temperature τ=0.02 (verified from the BGE-M3 paper; see the sketch after this list)
- Proper Loss Normalization: Fixed cross-entropy calculation in listwise loss
- Self-Knowledge Distillation combining dense, sparse, and multi-vector signals
- ANCE-style Hard Negative Mining with asynchronous ANN index updates
- Multi-granularity Support up to 8192 tokens
- Unified Training for all three functionalities
- Gradient Checkpointing for memory-efficient training on large batches
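As a concrete illustration of the InfoNCE objective with temperature scaling, here is a minimal PyTorch sketch assuming L2-normalized embeddings and in-batch negatives (info_nce_loss is an illustrative name, not this repo's API):

import torch
import torch.nn.functional as F

def info_nce_loss(q, p, tau=0.02):
    """q: (B, D) query embeddings; p: (B, D) positive passage embeddings."""
    logits = q @ p.T / tau                             # (B, B) similarity matrix
    labels = torch.arange(q.size(0), device=q.device)  # diagonal entries are the positives
    return F.cross_entropy(logits, labels)             # softmax over in-batch candidates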
BGE-Reranker (Cross-encoder)
- Listwise Cross-Entropy loss (fixed implementation, not KL divergence; see the sketch after this list)
- Dynamic Listwise Distillation following RocketQAv2 approach
- Hard Negative Selection from retriever results
- Knowledge Distillation from teacher models
- Mixed Precision Training with bfloat16 support for RTX 5090
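For reference, a minimal sketch of a listwise cross-entropy objective, assuming each query's candidate group places its single positive at index 0 (an illustrative sketch, not the repo's exact implementation):

import torch
import torch.nn.functional as F

def listwise_ce_loss(scores):
    """scores: (B, G) cross-encoder scores for G candidates per query."""
    labels = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
    return F.cross_entropy(scores, labels)  # softmax over each candidate group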
Joint Training (RocketQAv2)
- Mutual Distillation between retriever and reranker (see the sketch after this list)
- Dynamic Hard Negative Mining using evolving models
- Hybrid Data Augmentation strategies
- Error-resilient Training Loop with failure recovery
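A minimal sketch of the mutual-distillation signal in the RocketQAv2 style, assuming the retriever and reranker score the same candidate group per query; the retriever's distribution is pulled toward the detached reranker's via KL divergence (names are illustrative):

import torch.nn.functional as F

def distill_loss(retriever_scores, reranker_scores, tau=1.0):
    """Both inputs: (B, G) scores over the same G candidates per query."""
    teacher = F.softmax(reranker_scores.detach() / tau, dim=-1)  # soft labels
    student = F.log_softmax(retriever_scores / tau, dim=-1)
    return F.kl_div(student, teacher, reduction="batchmean")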
🛠 Installation
Hardware Requirements
- For RTX 5090: Python 3.10+, CUDA >= 12.8, Driver >= 545.x
- For Ascend NPU: Python 3.10+, CANN >= 7.0, torch_npu
- Memory: 24GB+ GPU memory recommended for full model training
Step 1: Install PyTorch (CUDA 12.8+ for RTX 5090)
# Install PyTorch nightly with CUDA 12.8 support
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
Step 2: Install Dependencies
pip install -r requirements.txt
Step 3: (Optional) Platform-specific Setup
For Huawei Ascend NPU:
# Install torch_npu (follow official Huawei documentation)
pip install torch_npu
# Set environment variables
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 # Your NPU devices
For Distributed Training:
# No additional setup needed - automatic backend selection:
# - NCCL for NVIDIA GPUs
# - HCCL for Ascend NPUs
# - Gloo for CPU
Step 4: Download Models (Optional - auto-downloads if not present)
# Models will auto-download on first use, or manually download:
huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3
huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base
🔧 Configuration
Configuration System Overview
The project uses a hierarchical configuration system with clear precedence rules and automatic validation. Configuration is managed through config.toml with programmatic overrides via command-line arguments.
Configuration Precedence Order (Highest to Lowest)
1. Command-line arguments: override any config file settings
2. Model-specific sections: [m3] for BGE-M3, [reranker] for BGE-Reranker
3. Hardware-specific sections: [hardware] and [ascend] for device-specific settings
4. Training defaults: the [training] section for general defaults
5. Dataclass field defaults: built-in defaults in code
Example precedence in action:
# config.toml
[training]
default_batch_size = 4
[hardware]
default_batch_size = 8 # Overrides [training] for hardware optimization
[m3]
per_device_train_batch_size = 16 # Overrides both [hardware] and [training] for M3 model
Main Configuration File (config.toml)
# API Configuration for augmentation and external services
[api]
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here" # Replace with your actual API key
model = "deepseek-chat"
max_retries = 3
timeout = 30
temperature = 0.7
# General training defaults
[training]
default_batch_size = 4
default_learning_rate = 1e-5
default_num_epochs = 3
default_warmup_ratio = 0.1
default_weight_decay = 0.01
gradient_accumulation_steps = 4
max_grad_norm = 1.0
logging_steps = 100
save_strategy = "epoch"
evaluation_strategy = "epoch"
load_best_model_at_end = true
metric_for_best_model = "eval_loss"
greater_is_better = false
save_total_limit = 3
fp16 = true
bf16 = false # Enable for A100/RTX5090
gradient_checkpointing = false
dataloader_num_workers = 4
remove_unused_columns = true
label_names = ["labels"]
# Model paths with automatic download
[model_paths]
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
cache_dir = "./cache/models"
# Data processing configuration
[data]
max_query_length = 64
max_passage_length = 512
max_seq_length = 512
validation_split = 0.1
random_seed = 42
preprocessing_num_workers = 4
overwrite_cache = false
pad_to_max_length = false
return_tensors = "pt"
# Hardware-specific optimizations
[hardware]
device_type = "cuda" # cuda, npu, cpu
mixed_precision = true
mixed_precision_dtype = "fp16" # fp16, bf16
enable_tf32 = true # For Ampere GPUs
cudnn_benchmark = true
cudnn_deterministic = false
per_device_train_batch_size = 8
per_device_eval_batch_size = 16
# Distributed training configuration
[distributed]
backend = "nccl" # nccl, gloo, hccl (auto-detected if not set)
timeout_minutes = 30
find_unused_parameters = false
gradient_as_bucket_view = true
static_graph = false
# Huawei Ascend NPU specific settings
[ascend]
device_type = "npu"
backend = "hccl"
mixed_precision = true
mixed_precision_dtype = "fp16"
visible_devices = "0,1,2,3"
enable_graph_mode = false
enable_profiling = false
profiling_start_step = 10
profiling_stop_step = 20
# BGE-M3 specific configuration
[m3]
model_name_or_path = "BAAI/bge-m3"
learning_rate = 1e-5
per_device_train_batch_size = 16
per_device_eval_batch_size = 32
num_train_epochs = 5
warmup_ratio = 0.1
temperature = 0.02 # InfoNCE temperature (from paper)
use_self_distill = true
self_distill_weight = 1.0
use_hard_negatives = true
num_hard_negatives = 15
negatives_cross_batch = true
pooling_method = "cls"
normalize_embeddings = true
use_sparse = true
use_colbert = true
colbert_dim = 32
# BGE-Reranker specific configuration
[reranker]
model_name_or_path = "BAAI/bge-reranker-base"
learning_rate = 6e-5
per_device_train_batch_size = 8
per_device_eval_batch_size = 16
num_train_epochs = 3
warmup_ratio = 0.06
train_group_size = 16
loss_type = "listwise_ce" # listwise_ce, ranknet, lambdarank
use_gradient_checkpointing = true
temperature = 1.0
use_distillation = false
teacher_model_path = "BAAI/bge-reranker-large"
distillation_weight = 0.5
# Data augmentation settings
[augmentation]
enable_query_augmentation = false
enable_passage_augmentation = false
augmentation_methods = ["paraphrase", "back_translation", "mask_and_fill"]
num_augmentations_per_sample = 2
augmentation_temperature = 0.8
api_augmentation = true
local_augmentation_model = "google/flan-t5-base"
# Hard negative mining configuration
[hard_negative_mining]
enable_mining = true
mining_strategy = "ance" # ance, random, bm25
num_hard_negatives = 15
negative_depth = 200 # How deep to search for negatives
update_interval = 1000 # Steps between mining updates
margin = 0.1
use_cross_batch_negatives = true
dedup_negatives = true
# FAISS index settings for mining
[faiss]
index_type = "IVF" # Flat, IVF, HNSW
index_params = {nlist = 100, nprobe = 10}
use_gpu = true
normalize_vectors = true
# Checkpoint and recovery settings
[checkpoint]
save_strategy = "epoch"
save_steps = 1000
save_total_limit = 3
save_on_each_node = false
resume_from_checkpoint = true
checkpoint_save_dir = "./checkpoints"
use_atomic_save = true # Prevents corruption
keep_backup = true
# Logging configuration
[logging]
logging_dir = "./logs"
logging_strategy = "steps"
logging_steps = 100
logging_first_step = true
report_to = ["tensorboard", "wandb"]
wandb_project = "bge-finetuning"
wandb_run_name = "" # Auto-generated when empty (TOML has no null)
disable_tqdm = false
tqdm_mininterval = 1
# Evaluation settings
[evaluation]
evaluation_strategy = "epoch"
eval_steps = 500
eval_batch_size = 32
metric_for_best_model = "eval_loss"
load_best_model_at_end = true
compute_metrics_fn = "accuracy_and_f1"
greater_is_better = false
# eval_accumulation_steps is unset by default (TOML has no null; omit the key)
# Memory optimization
[memory]
gradient_checkpointing = true
gradient_checkpointing_kwargs = {use_reentrant = false}
optim = "adamw_torch" # adamw_torch, adamw_hf, sgd, adafactor
# optim_args: omit the key to use optimizer defaults (TOML has no null)
clear_cache_interval = 100 # Steps between cache clearing
# max_memory_mb: omit the key to disable memory-based cache clearing (TOML has no null)
# Error handling and recovery
[error_handling]
continue_on_error = true
max_failed_batches = 10
error_log_file = "./logs/errors.log"
save_on_error = true
retry_failed_batches = true
num_retries = 3
Command-Line Override Examples
# Override batch size and learning rate
python scripts/train_m3.py \
--per_device_train_batch_size 32 \
--learning_rate 2e-5
# Override hardware settings
python scripts/train_m3.py \
--device npu \
--mixed_precision_dtype bf16
# Override multiple config sections
python scripts/train_m3.py \
--config_file custom_config.toml \
--temperature 0.05 \
--num_hard_negatives 20 \
--gradient_checkpointing
Environment Variable Support
# Set API key via environment variable
export BGE_API_KEY="your-api-key"
# Override config file location
export BGE_CONFIG_FILE="/path/to/custom/config.toml"
# Set distributed training environment
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export WORLD_SIZE=4
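A minimal sketch of how a loader might honor these variables; the actual utils/config_loader.py implementation may differ:

import os

config_file = os.environ.get("BGE_CONFIG_FILE", "config.toml")  # fall back to the default path
api_key = os.environ.get("BGE_API_KEY")  # overrides [api].api_key when set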
Configuration Validation
The configuration system automatically validates settings:
# In your code
from config.base_config import BaseConfig
config = BaseConfig.from_toml("config.toml")
# Automatic validation includes:
# - Type checking (int, float, str, bool, list, dict)
# - Value range validation (e.g., learning_rate > 0)
# - Compatibility checks (e.g., bf16 requires compatible hardware)
# - Missing required fields
# - Deprecated parameter warnings
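As an illustration of these checks, here is a hedged sketch of a config dataclass with post-init validation; the field names follow config.toml, but the validation code itself is hypothetical:

import torch
from dataclasses import dataclass

def bf16_supported():
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()

@dataclass
class TrainingDefaults:
    default_learning_rate: float = 1e-5
    default_batch_size: int = 4
    mixed_precision_dtype: str = "fp16"

    def __post_init__(self):
        if self.default_learning_rate <= 0:
            raise ValueError("learning_rate must be > 0")
        if self.default_batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if self.mixed_precision_dtype == "bf16" and not bf16_supported():
            raise ValueError("bf16 requires compatible hardware (e.g., Ampere or newer)")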
Dynamic Configuration
# Programmatic configuration updates
import torch
from config.m3_config import M3Config

config = M3Config.from_toml("config.toml")

# Update based on hardware
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    config.mixed_precision_dtype = "bf16"
    config.per_device_train_batch_size *= 2

# Update based on dataset size
if dataset_size > 1000000:
    config.gradient_accumulation_steps = 8
    config.logging_steps = 500
Best Practices
- Start with defaults: The provided config.toml has sensible defaults
- Hardware-specific tuning: Use the [hardware] section for GPU/NPU optimization
- Model-specific settings: Use [m3] or [reranker] for model-specific params
- Environment isolation: Use different config files for dev/staging/prod
- Version control: Track config.toml changes in git
- Secrets management: Use environment variables for API keys
🚀 Quick Start
Step 1: Prepare Your Data
Create your training data in JSONL format. Each line should be a valid JSON object.
For embedding models (BGE-M3):
{"query": "What is machine learning?", "pos": ["Machine learning is..."], "neg": ["The weather is..."]}
{"query": "How to cook pasta?", "pos": ["Boil water and add pasta..."], "neg": ["Machine learning..."]}
For reranker models:
{"query": "What is AI?", "passage": "AI is artificial intelligence...", "label": 1}
{"query": "What is AI?", "passage": "The weather today is sunny...", "label": 0}
Step 2: Fine-tune BGE-M3 (Embedding Model)
# Basic training with automatic hardware detection
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-m3-finetuned \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--learning_rate 1e-5
# Advanced training with all features enabled
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-m3-advanced \
--num_train_epochs 5 \
--per_device_train_batch_size 8 \
--learning_rate 1e-5 \
--use_self_distill \
--use_hard_negatives \
--gradient_checkpointing \
--bf16 \
--save_total_limit 3 \
--logging_steps 100 \
--evaluation_strategy "steps" \
--eval_steps 500 \
--metric_for_best_model "eval_loss" \
--load_best_model_at_end
Step 3: Fine-tune BGE-Reranker (Cross-encoder)
# Basic reranker training
python scripts/train_reranker.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-reranker-finetuned \
--learning_rate 6e-5 \
--train_group_size 16 \
--loss_type listwise_ce
# Advanced reranker training with knowledge distillation
python scripts/train_reranker.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-reranker-advanced \
--learning_rate 6e-5 \
--train_group_size 16 \
--loss_type listwise_ce \
--use_distillation \
--teacher_model_path BAAI/bge-reranker-large \
--bf16 \
--gradient_checkpointing \
--save_strategy "epoch" \
--evaluation_strategy "epoch"
Step 4: Joint Training (RocketQAv2-style)
# Joint training with mutual distillation
python scripts/train_joint.py \
--train_data data/train.jsonl \
--output_dir ./output/joint-training \
--use_dynamic_distillation \
--joint_loss_weight 0.3 \
--bf16 \
--gradient_checkpointing \
--max_grad_norm 1.0 \
--dataloader_num_workers 4
Step 5: Multi-GPU / Distributed Training
# NVIDIA GPUs (automatic NCCL backend)
torchrun --nproc_per_node=4 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/distributed \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 2 \
--bf16 \
--gradient_checkpointing
# Huawei Ascend NPUs (automatic HCCL backend)
# First, configure config.toml [ascend] section, then:
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/ascend \
--device npu
📊 Data Format
Canonical Format (Required Structure)
All data formats must contain these required fields:
- query (string): The search query or question
- pos (list of strings): Positive (relevant) passages
- neg (list of strings): Negative (irrelevant) passages
Optional fields for advanced training:
- pos_scores (list of floats): Relevance scores for positives (0.0-1.0)
- neg_scores (list of floats): Relevance scores for negatives (0.0-1.0)
- prompt (string): Custom prompt for encoding
Supported Input Formats
Our dataset loader automatically detects and validates the following formats:
1. JSONL (Recommended - Most Efficient)
{"query": "What is machine learning?", "pos": ["Machine learning is..."], "neg": ["Deep learning is..."]}
{"query": "深度学习的应用", "pos": ["深度学习在计算机视觉..."], "neg": ["机器学习是..."]}
2. JSON (Array Format)
[
{
"query": "What is BGE?",
"pos": ["BGE is a family of embedding models..."],
"neg": ["BERT is a transformer model..."],
"pos_scores": [1.0],
"neg_scores": [0.2]
}
]
3. CSV/TSV
query,pos,neg,pos_scores,neg_scores
"What is machine learning?","['ML is...', 'Machine learning refers to...']","['DL is...']","[1.0, 0.9]","[0.1]"
4. DuReader Format
{
"data": [{
"question": "什么是机器学习?",
"answers": ["机器学习是..."],
"documents": [{
"content": "深度学习是...",
"is_relevant": false
}]
}]
}
5. MS MARCO Format (TSV)
query_id query doc_id document relevance
1001 What is ML? 2001 Machine learning is... 1
1001 What is ML? 2002 Deep learning is... 0
6. XML Format (TREC-style)
<topics>
<topic>
<query>Information retrieval basics</query>
<positive>IR is the process of obtaining...</positive>
<negative>Database management is...</negative>
</topic>
</topics>
Data Validation & Error Handling
The dataset loader performs comprehensive validation:
- Format Detection: Automatic format inference with helpful error messages
- Field Validation: Ensures all required fields are present and properly typed
- Content Validation: Checks for empty queries/passages, invalid scores
- Encoding Validation: Handles various text encodings (UTF-8, GB2312, etc.)
- Size Validation: Warns about extremely long sequences
Example validation output:
Dataset validation completed:
✓ Format: JSONL
✓ Total samples: 10,000
✓ Valid samples: 9,998
✗ Invalid samples: 2
- Line 1523: Empty query field
- Line 7891: Missing 'neg' field
✓ Encoding: UTF-8
⚠ Warnings:
- 15 samples have passages > 8192 tokens (will be truncated)
Converting Data Formats
Use the built-in preprocessor to convert between formats:
from data.preprocessing import DataPreprocessor

preprocessor = DataPreprocessor()

# Convert any format to canonical JSONL
preprocessor.convert_to_canonical(
    input_file="data/msmarco.tsv",
    output_file="data/train.jsonl",
    input_format="msmarco"  # or "auto" for automatic detection
)

# Batch conversion with validation
preprocessor.batch_convert(
    input_files=["data/*.csv", "data/*.json"],
    output_dir="data/processed/",
    validate=True,
    max_workers=4
)
Working with Book/Document Collections
For training on books or large documents:
- Split documents into passages:
from data.preprocessing import DocumentSplitter

splitter = DocumentSplitter(
    chunk_size=512,
    overlap=50,
    split_by="sentence"  # or "paragraph", "sliding_window"
)
passages = splitter.split_document("path/to/book.txt")
- Generate training pairs:
from data.preprocessing import TrainingPairGenerator

generator = TrainingPairGenerator()

# Generate query-passage pairs with hard negatives
training_data = generator.generate_from_passages(
    passages=passages,
    queries_per_passage=2,
    hard_negative_ratio=0.7,
    use_bm25=True
)

# Save in canonical format
generator.save_jsonl(training_data, "data/book_train.jsonl")
- Use data optimization:
# Optimize data for BGE-M3 training
python -m data.optimization bge-m3 data/book_train.jsonl \
--output_dir ./optimized_data \
--hard_negative_ratio 0.8 \
--augment_queries
Memory-Efficient Data Loading
For large datasets, use streaming mode:
# In your training script
dataset = ContrastiveDataset(
    data_file="data/large_dataset.jsonl",
    tokenizer=tokenizer,
    streaming=True,    # Enable streaming
    buffer_size=10000  # Buffer size for shuffling
)
📖 Training with Book .txt Files
- You cannot train directly on a raw .txt book file.
- Convert your book to a supported format:
  - Split the book into passages (e.g., by chapter or paragraph).
  - Create queries and annotate which passages are positive (relevant) and negative (irrelevant) for each query.
  - Example:
    {"query": "What happens in Chapter 1?", "pos": ["Chapter 1: ..."], "neg": ["Chapter 2: ..."]}
- Use the provided DataPreprocessor to convert CSV/TSV/JSON/XML to JSONL.
Note: All formats must be convertible to the canonical JSONL structure above. Use the provided DataPreprocessor to convert between formats.
⚙️ Configuration
Parameter Precedence
- [hardware] overrides [training] for overlapping parameters (e.g., batch size).
- [m3] and [reranker] override [training]/[hardware] for their respective models.
- For Ascend NPU, use the [ascend] section and set device_type = "npu".
Hardware Compatibility
- NVIDIA RTX 5090: Use torch >=2.7.1 and CUDA 12.8+ (see above).
- Huawei Ascend NPU: Install torch_npu, set the [hardware] or [ascend] config, and ensure your environment matches Ascend requirements.
API Configuration (config.toml)
[api]
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here"
model = "deepseek-chat"
[training]
default_batch_size = 4
default_learning_rate = 1e-5
[model_paths]
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
Configuration Parameter Reference
[api]
| Parameter | Type | Default | Description |
|---|---|---|---|
| base_url | str | "https://api.deepseek..." | API endpoint for augmentation |
| api_key | str | "your-api-key-here" | API key for external services |
| model | str | "deepseek-chat" | Model name for API |
| max_retries | int | 3 | Max API retries |
| timeout | int | 30 | API timeout (seconds) |
| temperature | float | 0.7 | API sampling temperature |
[training]
| Parameter | Type | Default | Description |
|---|---|---|---|
| default_batch_size | int | 4 | Default batch size for training |
| default_learning_rate | float | 1e-5 | Default learning rate |
| default_num_epochs | int | 3 | Default number of epochs |
| default_warmup_ratio | float | 0.1 | Warmup steps ratio |
| default_weight_decay | float | 0.01 | Weight decay for optimizer |
| gradient_accumulation_steps | int | 4 | Steps to accumulate gradients |
[model_paths]
| Parameter | Type | Default | Description |
|---|---|---|---|
| bge_m3 | str | "./models/bge-m3" | Path to BGE-M3 model |
| bge_reranker | str | "./models/bge-reranker-base" | Path to BGE-Reranker model |
| cache_dir | str | "./cache/models" | Model cache directory |
[data]
| Parameter | Type | Default | Description |
|---|---|---|---|
| max_query_length | int | 64 | Max query length |
| max_passage_length | int | 512 | Max passage length |
| max_seq_length | int | 512 | Max sequence length |
| validation_split | float | 0.1 | Validation split ratio |
| random_seed | int | 42 | Random seed for reproducibility |
[augmentation]
| Parameter | Type | Default | Description |
|---|---|---|---|
| enable_query_augmentation | bool | false | Enable query augmentation |
| enable_passage_augmentation | bool | false | Enable passage augmentation |
| augmentation_methods | list | ["paraphrase", "back_translation"] | Augmentation methods to use |
| num_augmentations_per_sample | int | 2 | Number of augmentations per sample |
| augmentation_temperature | float | 0.8 | Temperature for augmentation |
| api_augmentation | bool | true | Prefer the LLM API for augmentation when enabled |
[hard_negative_mining]
| Parameter | Type | Default | Description |
|---|---|---|---|
| enable_mining | bool | true | Enable hard negative mining |
| num_hard_negatives | int | 15 | Number of hard negatives to mine |
| negative_range_min | int | 10 | Min negative index for mining |
| negative_range_max | int | 100 | Max negative index for mining |
| margin | float | 0.1 | Margin for ANCE mining |
| index_refresh_interval | int | 1000 | Steps between index refreshes |
| index_type | str | "IVF" | FAISS index type |
| nlist | int | 100 | Number of clusters for IVF |
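To make the mining parameters above concrete, here is a hedged sketch of ANCE-style negative selection over a FAISS inner-product index, assuming L2-normalized embeddings; mine_hard_negatives is an illustrative name, not the module's actual API:

import numpy as np
import faiss

def mine_hard_negatives(query_embs, passage_embs, positive_ids,
                        num_hard_negatives=15, negative_depth=200, margin=0.1):
    """Pick negatives that score close to, but below, each query's positive."""
    index = faiss.IndexFlatIP(passage_embs.shape[1])  # inner product == cosine for normalized vectors
    index.add(passage_embs.astype(np.float32))
    scores, ids = index.search(query_embs.astype(np.float32), negative_depth)

    hard_negatives = []
    for q, (q_scores, q_ids) in enumerate(zip(scores, ids)):
        pos_score = float(query_embs[q] @ passage_embs[positive_ids[q]])
        negs = [int(pid) for pid, s in zip(q_ids, q_scores)
                if pid != positive_ids[q] and s < pos_score - margin]
        hard_negatives.append(negs[:num_hard_negatives])
    return hard_negatives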
[m3] and [reranker]
- All model-specific parameters are documented in the config classes and code docstrings. See config/m3_config.py and config/reranker_config.py for full details.
⚡️ Distributed Training
Overview
The project supports distributed training across multiple GPUs/NPUs with automatic backend detection, error recovery, and optimized communication strategies. The distributed training system is fully configuration-driven and handles both data-parallel and model-parallel strategies.
Automatic Backend Detection
The system automatically selects the optimal backend based on your hardware:
- NVIDIA GPUs: NCCL (fastest for GPU-to-GPU communication)
- Huawei Ascend NPUs: HCCL (optimized for Ascend architecture)
- CPU or Mixed: Gloo (fallback for heterogeneous setups)
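A minimal sketch of this selection logic, assuming detection via torch.cuda and the optional torch_npu import; the project's utils/distributed.py may implement it differently:

import torch

def pick_backend():
    if torch.cuda.is_available():
        return "nccl"   # NVIDIA GPUs
    try:
        import torch_npu  # noqa: F401  Ascend extension, if installed
        return "hccl"   # Huawei Ascend NPUs
    except ImportError:
        return "gloo"   # CPU or heterogeneous fallback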
Basic Usage
Single-Node Multi-GPU (NVIDIA)
# Automatic detection and setup
torchrun --nproc_per_node=4 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/distributed
# With specific GPU selection
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/distributed \
--gradient_accumulation_steps 2
Multi-Node Training (NVIDIA)
# On node 0 (master)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr=192.168.1.1 --master_port=29500 \
scripts/train_m3.py --train_data data/train.jsonl
# On node 1
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
--master_addr=192.168.1.1 --master_port=29500 \
scripts/train_m3.py --train_data data/train.jsonl
Huawei Ascend NPU Training
# Single-node multi-NPU
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/ascend \
--device npu
# With specific NPU selection
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/ascend
Configuration Options
# config.toml
[distributed]
backend = "auto"  # auto, nccl, gloo, hccl
timeout_minutes = 30
find_unused_parameters = false
gradient_as_bucket_view = true
broadcast_buffers = true
bucket_cap_mb = 25
use_distributed_optimizer = false

# Performance optimizations
[distributed.optimization]
overlap_comm = true
gradient_compression = false
compression_type = "none"  # none, powerSGD, fp16
all_reduce_fusion = true
fusion_threshold_mb = 64

# Fault tolerance
[distributed.fault_tolerance]
enable_checkpointing = true
checkpoint_interval = 1000
restart_on_failure = true
max_restart_attempts = 3
heartbeat_interval = 60
Advanced Features
1. Gradient Accumulation with Distributed Training
# Effective batch size = per_device_batch * gradient_accumulation * num_gpus
# Example: 8 * 4 * 4 = 128 effective batch size
python scripts/train_m3.py \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 4 \
--distributed
2. Mixed Precision with Distributed
# Automatic mixed precision for faster training
torchrun --nproc_per_node=4 scripts/train_m3.py \
--bf16 \
--tf32 # Additional optimization for Ampere GPUs
3. DeepSpeed Integration
# With DeepSpeed for extreme scale
deepspeed scripts/train_m3.py \
--deepspeed deepspeed_config.json \
--train_data data/train.jsonl
Example DeepSpeed config:
{
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
}
}
4. Fault Tolerance and Recovery
The distributed training system includes robust error handling:
# Training automatically recovers from:
# - Temporary network failures
# - GPU memory errors (with retry)
# - Node failures (with checkpoint recovery)
# - Communication timeouts
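A hedged sketch of what batch-level recovery can look like, using the max_failed_batches threshold from [error_handling] and assuming an HF-style model output with a .loss attribute (train_with_recovery is an illustrative name):

import logging
import torch

logger = logging.getLogger(__name__)

def train_with_recovery(model, dataloader, optimizer, max_failed_batches=10):
    failed = 0
    for step, batch in enumerate(dataloader):
        try:
            loss = model(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        except RuntimeError as err:         # e.g. a transient CUDA OOM
            failed += 1
            logger.warning("Batch %d failed: %s", step, err)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()    # release cached memory before continuing
            if failed > max_failed_batches:
                raise                       # threshold exceeded: surface the error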
5. Performance Monitoring
# Enable distributed profiling
python scripts/train_m3.py \
--enable_profiling \
--profiling_start_step 100 \
--profiling_stop_step 200
# Monitor with tensorboard
tensorboard --logdir ./logs --bind_all
Best Practices
- Data Loading
  # The distributed sampler is set up automatically; set num_workers based on CPU cores
  --dataloader_num_workers 4
- Batch Size Scaling (see the worked example after this list)
  # Scale the learning rate with the total batch size:
  # lr = base_lr * (total_batch_size / base_batch_size)
- Communication Optimization
  # Reduce communication overhead
  --gradient_accumulation_steps 4
  --broadcast_buffers false  # If not needed
- Memory Optimization
  # For large models
  --gradient_checkpointing
  --zero_stage 2  # With DeepSpeed
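Worked example of the batch-size scaling rule above (plain arithmetic, not a project API):

base_lr, base_batch = 1e-5, 16
total_batch = 4 * 8 * 4                  # gpus * per_device_batch * grad_accum = 128
lr = base_lr * total_batch / base_batch  # -> 8e-5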
Troubleshooting
NCCL Errors
# Debug NCCL issues
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
# Common fixes
export NCCL_IB_DISABLE=1 # Disable InfiniBand
export NCCL_P2P_DISABLE=1 # Disable P2P
Hanging Training
# Set timeout for debugging
export NCCL_TIMEOUT=300 # 5 minutes
export NCCL_ASYNC_ERROR_HANDLING=1
Memory Issues
# Clear cache periodically
--clear_cache_interval 100
# Use gradient checkpointing
--gradient_checkpointing
Performance Benchmarks
Typical speedup with distributed training:
- 2 GPUs: 1.8-1.9x speedup
- 4 GPUs: 3.5-3.8x speedup
- 8 GPUs: 6.5-7.5x speedup
Factors affecting performance:
- Model size (larger models scale better)
- Batch size (larger batches reduce communication overhead)
- Network bandwidth (especially for multi-node)
- Data loading speed
📊 Evaluation Metrics
Overview
The evaluation system provides comprehensive metrics for both retrieval and reranking tasks, with support for graded relevance, statistical significance testing, and performance benchmarking.
Core Metrics
Retrieval Metrics
- Recall@K: fraction of relevant documents retrieved in the top K
  recall = |retrieved_relevant| / |total_relevant|
- Precision@K: fraction of retrieved documents that are relevant
  precision = |retrieved_relevant| / K
- Mean Average Precision (MAP): mean of the average precision over queries
  MAP = mean([AP(q) for q in queries]), where AP = sum([P@k * rel(k) for k in positions]) / |relevant|
- Mean Reciprocal Rank (MRR@K): average reciprocal rank of the first relevant result
  MRR = mean([1 / rank_of_first_relevant for q in queries])
- Normalized Discounted Cumulative Gain (NDCG@K): position-weighted relevance score
  DCG@K = sum([rel(i) / log2(i + 1) for i in 1..K]); NDCG@K = DCG@K / IDCG@K
- Success Rate@K: fraction of queries with at least one relevant result in the top K
  success_rate = |queries_with_relevant_in_topK| / |total_queries|
- Coverage@K: fraction of all relevant documents found across all queries
  coverage = |unique_relevant_retrieved| / |total_unique_relevant|
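For reference, a minimal plain-Python sketch of three of these metrics, assuming ranked doc-ID lists and a doc-to-grade mapping; this is illustrative, not the project's evaluation/metrics.py:

import math

def recall_at_k(retrieved, relevant, k):
    """Fraction of relevant docs that appear in the top-k list."""
    if not relevant:
        return 0.0
    return sum(1 for doc in retrieved[:k] if doc in relevant) / len(relevant)

def mrr_at_k(retrieved, relevant, k):
    """Reciprocal rank of the first relevant doc in the top k (0 if none)."""
    for rank, doc in enumerate(retrieved[:k], start=1):
        if doc in relevant:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(retrieved, relevance, k):
    """NDCG@k with graded relevance; `relevance` maps doc -> grade."""
    dcg = sum(relevance.get(doc, 0) / math.log2(i + 1)
              for i, doc in enumerate(retrieved[:k], start=1))
    ideal = sorted(relevance.values(), reverse=True)[:k]
    idcg = sum(rel / math.log2(i + 1) for i, rel in enumerate(ideal, start=1))
    return dcg / idcg if idcg > 0 else 0.0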
Reranking Metrics
- Accuracy@1 - Top-1 ranking accuracy
- MRR - Mean reciprocal rank for reranked results
- NDCG@K - Ranking quality with graded relevance
- ERR (Expected Reciprocal Rank) - Cascade model metric
Usage Examples
Basic Evaluation
# Evaluate retriever
python scripts/evaluate.py \
--eval_type retriever \
--model_path ./output/bge-m3-finetuned \
--eval_data data/test.jsonl \
--k_values 1 5 10 20 50 \
--output_file results/retriever_metrics.json
# Evaluate reranker
python scripts/evaluate.py \
--eval_type reranker \
--model_path ./output/bge-reranker-finetuned \
--eval_data data/test.jsonl \
--output_file results/reranker_metrics.json
# Evaluate full pipeline
python scripts/evaluate.py \
--eval_type pipeline \
--retriever_path ./output/bge-m3-finetuned \
--reranker_path ./output/bge-reranker-finetuned \
--eval_data data/test.jsonl \
--retrieval_top_k 50 \
--rerank_top_k 10
Advanced Evaluation with Graded Relevance
# Data format with relevance scores
{
"query": "machine learning basics",
"documents": [
{"text": "Introduction to ML", "relevance": 2}, # Highly relevant
{"text": "Deep learning guide", "relevance": 1}, # Somewhat relevant
{"text": "Database systems", "relevance": 0} # Not relevant
]
}
# Evaluate with graded relevance
python scripts/evaluate.py \
--eval_type retriever \
--model_path ./output/bge-m3-finetuned \
--eval_data data/test_graded.jsonl \
--use_graded_relevance \
--relevance_threshold 1 \
--ndcg_k_values 5 10 20
Statistical Significance Testing
# Compare two models with significance testing
python scripts/compare_retriever.py \
--baseline_model_path BAAI/bge-m3 \
--finetuned_model_path ./output/bge-m3-finetuned \
--data_path data/test.jsonl \
--num_bootstrap_samples 1000 \
--confidence_level 0.95
# Output includes:
# - Metric comparisons with confidence intervals
# - Statistical significance (p-values)
# - Effect sizes (Cohen's d)
Metric Computation Details
Handling Edge Cases
- Empty Results
  - Recall: 0 when the query has no relevant documents
  - Precision: 0 for empty result lists
  - MRR: 0 if no relevant document is found
  - NDCG: 0 for empty result lists
- Tied Scores
  - Uses stable sorting to maintain consistency
  - Averages metrics across tied positions
- Missing Relevance Labels
  - Binary relevance: treats unlabeled documents as non-relevant
  - Graded relevance: requires explicit labels
Efficiency Optimizations
# Batch evaluation for large datasets
evaluator = Evaluator(
    model=model,
    batch_size=128,
    use_fp16=True,
    device="cuda"
)

# Streaming evaluation for memory efficiency
results = evaluator.evaluate_streaming(
    data_path="large_dataset.jsonl",
    chunk_size=10000
)
Performance Benchmarking
# Benchmark inference speed
python scripts/benchmark.py \
--model_type retriever \
--model_path ./output/bge-m3-finetuned \
--data_path data/test.jsonl \
--batch_sizes 1 8 16 32 64 \
--num_warmup_steps 10 \
--num_benchmark_steps 100
# Output includes:
# - Throughput (queries/second)
# - Latency percentiles (p50, p90, p95, p99)
# - Memory usage
# - Hardware utilization
Custom Metrics
from evaluation.metrics import BaseMetric

class CustomMetric(BaseMetric):
    def compute(self, predictions, references, k=10):
        # Your custom metric logic
        return score

# Register the custom metric
from evaluation.evaluator import Evaluator

evaluator = Evaluator(model=model)
evaluator.add_metric("custom_metric", CustomMetric())
Visualization and Reporting
# Generate evaluation report
from evaluation.reporting import EvaluationReport
report = EvaluationReport(results)
report.generate_html("evaluation_report.html")
report.generate_latex("evaluation_report.tex")
report.plot_metrics("plots/")
# Example plots:
# - Recall@K curves
# - Precision-Recall curves
# - NDCG@K comparison
# - Query-level performance distribution
Best Practices
- Use Appropriate K Values
  - Small K (1, 5, 10) for user-facing metrics
  - Large K (50, 100) for reranking scenarios
- Consider Query Types
  - Navigational queries: focus on MRR and Success@1
  - Informational queries: focus on NDCG and Recall@K
- Statistical Rigor
  - Use test sets with 1000+ queries
  - Report confidence intervals
  - Use paired tests for model comparison
- Graded Relevance
  - Use NDCG for nuanced evaluation
  - Define clear relevance guidelines
  - Consider query-specific relevance scales
📈 Model Evaluation & Benchmarking
1. Detailed Accuracy Evaluation
Use scripts/evaluate.py for comprehensive accuracy evaluation of retriever, reranker, or the full pipeline. Supports all standard metrics (MRR, Recall@k, MAP, NDCG, etc.) and baseline comparison.
Example:
python scripts/evaluate.py --retriever_model path/to/finetuned --eval_data path/to/data.jsonl --eval_retriever --k_values 1 5 10
python scripts/evaluate.py --reranker_model path/to/finetuned --eval_data path/to/data.jsonl --eval_reranker
python scripts/evaluate.py --retriever_model path/to/finetuned --reranker_model path/to/finetuned --eval_data path/to/data.jsonl --eval_pipeline --compare_baseline --base_retriever path/to/baseline
2. Quick Performance Benchmarking
Use scripts/benchmark.py for fast, single-model performance profiling (throughput, latency, memory). This script does not report accuracy metrics.
Example:
python scripts/benchmark.py --model_type retriever --model_path path/to/model --data_path path/to/data.jsonl --batch_size 16 --device cuda
python scripts/benchmark.py --model_type reranker --model_path path/to/model --data_path path/to/data.jsonl --batch_size 16 --device cuda
3. Fine-tuned vs. Baseline Comparison (Accuracy & Performance)
Use scripts/compare_retriever.py and scripts/compare_reranker.py to compare a fine-tuned model against a baseline, reporting both accuracy and performance metrics side by side, including the delta (improvement) for each metric.
Example:
python scripts/compare_retriever.py --finetuned_model_path path/to/finetuned --baseline_model_path path/to/baseline --data_path path/to/data.jsonl --batch_size 16 --device cuda
python scripts/compare_reranker.py --finetuned_model_path path/to/finetuned --baseline_model_path path/to/baseline --data_path path/to/data.jsonl --batch_size 16 --device cuda
- Results are saved to a file and include a table of metrics for both models and the improvement (delta).
🖥️ Ascend NPU Section
- Install: pip install torch_npu (see the Ascend docs)
- Edit config.toml:
  [ascend]
  device_type = "npu"
  backend = "hccl"
  mixed_precision = true
  visible_devices = "0,1,2,3"
- Launch training:
  python3.10 scripts/train_m3.py --train_data data/train.jsonl --output_dir ./output/npu
- Distributed: use Ascend's distributed launcher or set backend = "hccl" in [ascend].
- All device/backend selection is config-driven.
🚀 Performance Optimization
Multi-GPU Training
torchrun --nproc_per_node=2 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/distributed \
--fp16 \
--gradient_checkpointing
Memory Optimization
- Gradient Checkpointing: Reduces memory usage
- Mixed Precision: FP16/BF16 training (BF16 recommended for RTX 5090)
- Gradient Accumulation: Effective larger batch sizes
- DeepSpeed Integration: For large-scale training
Huawei Ascend Compatibility
# For Ascend deployment
import torch
import torch_npu  # registers the "npu" device type

device = torch.device("npu")
model.to(device)
📝 Recent Changes & Improvements
✅ Simplified Data Pipeline (Latest)
- Removed the unnecessary PKL conversion pipeline: the data/optimization.py module has been removed
  - Performance benefits came from pre-tokenization, not the PKL format
  - PKL files caused memory issues and added complexity without real benefit
- Added an in-memory caching option: for small datasets (< 100k samples), enable cache_in_memory=True
- Kept JSONL as the primary format: industry standard, human-readable, and efficient
- All core functionality intact: training, evaluation, and model features unchanged
🚀 Performance Tips
For better training performance:
# Enable in-memory caching for small datasets
dataset = BGEM3Dataset(
    data_path="train.jsonl",
    tokenizer=tokenizer,
    cache_in_memory=True,   # Pre-tokenize and cache in memory
    max_cache_size=100000   # Maximum number of samples to cache
)
🏗 Architecture
bge_finetune/
├── config.toml # Main configuration file
├── requirements.txt # Python dependencies
├── README.md # Project documentation
├── __init__.py # Package initialization
│
├── config/ # Configuration management
│ ├── __init__.py
│ ├── base_config.py # Base configuration class with validation
│ ├── m3_config.py # BGE-M3 specific configuration
│ └── reranker_config.py # BGE-Reranker specific configuration
│
├── data/ # Data processing and management
│ ├── __init__.py
│ ├── dataset.py # Dataset classes (Contrastive, Reranker, Joint)
│ ├── preprocessing.py # Format conversion and validation
│ ├── augmentation.py # Query/passage augmentation strategies
│ ├── hard_negative_mining.py # ANCE-style hard negative mining
│ └── optimization.py # Data optimization utilities
│
├── models/ # Model implementations
│ ├── __init__.py
│ ├── bge_m3.py # BGE-M3 wrapper with multi-functionality
│ ├── bge_reranker.py # BGE-Reranker cross-encoder wrapper
│ └── losses.py # Loss functions (InfoNCE, ListwiseCE, etc.)
│
├── training/ # Training logic
│ ├── __init__.py
│ ├── trainer.py # Base trainer with error recovery
│ ├── m3_trainer.py # BGE-M3 specialized trainer
│ ├── reranker_trainer.py # Reranker specialized trainer
│ └── joint_trainer.py # RocketQAv2-style joint training
│
├── evaluation/ # Evaluation and metrics
│ ├── __init__.py
│ ├── evaluator.py # Evaluation pipeline
│ ├── metrics.py # Metric implementations (MRR, NDCG, etc.)
│ └── reporting.py # Result visualization and reporting
│
├── utils/ # Utilities and helpers
│ ├── __init__.py
│ ├── logging.py # Distributed-aware logging with tqdm
│ ├── checkpoint.py # Atomic checkpoint save/load
│ ├── distributed.py # Distributed training utilities
│ ├── config_loader.py # TOML config loading with precedence
│ ├── memory_utils.py # Memory optimization utilities
│ └── ascend_support.py # Huawei Ascend NPU support
│
├── scripts/ # Entry point scripts
│ ├── __init__.py
│ ├── train_m3.py # Train BGE-M3 embeddings
│ ├── train_reranker.py # Train BGE-Reranker
│ ├── train_joint.py # Joint training (RocketQAv2)
│ ├── evaluate.py # Comprehensive evaluation
│ ├── benchmark.py # Performance benchmarking
│ ├── compare_retriever.py # Retriever comparison
│ ├── compare_reranker.py # Reranker comparison
│ ├── test_installation.py # Environment verification
│ └── setup.py # Initial setup script
│
├── cache/ # Cache directory (auto-created)
│ ├── models/ # Downloaded model cache
│ └── data/ # Preprocessed data cache
│
├── logs/ # Logging directory (auto-created)
│ ├── training/ # Training logs
│ ├── evaluation/ # Evaluation results
│ └── errors.log # Error logs
│
├── output/ # Training outputs (auto-created)
│ ├── bge-m3-finetuned/ # Fine-tuned models
│ ├── bge-reranker-finetuned/ # Fine-tuned rerankers
│ └── checkpoints/ # Training checkpoints
│
└── docs/ # Additional documentation
├── API.md # API reference
├── CONTRIBUTING.md # Contribution guidelines
└── USAGE_GUIDE.md # Detailed usage guide
Module Descriptions
Core Modules
- config/: Hierarchical configuration system with TOML support and validation
  - Automatic type checking and value validation
  - Clear precedence rules (CLI > Model > Hardware > Training > Defaults)
  - Dynamic configuration based on hardware detection
- data/: Comprehensive data handling
  - Format detection and conversion (JSONL, JSON, CSV, TSV, XML, DuReader, MS MARCO)
  - Streaming support for large datasets
  - Built-in augmentation with LLM and local model support
  - ANCE-style hard negative mining with FAISS integration
- models/: Model wrappers and loss functions
  - BGE-M3: dense, sparse (SPLADE), and multi-vector (ColBERT) support
  - BGE-Reranker: cross-encoder with knowledge distillation
  - Loss functions: InfoNCE (τ=0.02), ListwiseCE, RankNet, LambdaRank
- training/: Robust training infrastructure
  - Error recovery with configurable failure thresholds
  - Automatic mixed precision (fp16/bf16) based on hardware
  - Gradient checkpointing and memory optimization
  - Distributed training with automatic backend selection
- evaluation/: Comprehensive evaluation suite
  - Metrics: Recall@K, MRR@K, NDCG@K, MAP, Success Rate, Coverage
  - Graded relevance support
  - Statistical significance testing
  - Performance benchmarking
- utils/: Production-ready utilities
  - TqdmLoggingHandler for progress tracking
  - Atomic checkpoint saving with backup
  - Automatic distributed backend detection (NCCL/HCCL/Gloo)
  - Full Ascend NPU support module
Entry Points
- Training Scripts:
  - train_m3.py: BGE-M3 training with all features
  - train_reranker.py: Reranker training with distillation
  - train_joint.py: RocketQAv2-style joint training
- Evaluation Scripts:
  - evaluate.py: Full evaluation with all metrics
  - benchmark.py: Performance profiling
  - compare_*.py: Side-by-side model comparison
- Utility Scripts:
  - test_installation.py: Verify environment setup
  - setup.py: Download models and prepare the environment
Key Features by Module
- Automatic Hardware Detection
  - CUDA capability checking for optimal dtype (fp16/bf16)
  - Ascend NPU detection and configuration
  - CPU fallback for development
- Robust Error Handling
  - Training continues despite batch failures
  - Automatic checkpoint recovery
  - Comprehensive error logging
- Memory Efficiency
  - Gradient checkpointing
  - Periodic cache clearing
  - Streaming data loading
  - Mixed precision training
- Production Features
  - Atomic checkpoint saves
  - Distributed training support
  - Comprehensive logging
  - Configuration validation
🔧 Troubleshooting & FAQ
Q: My GPU is not detected or torch reports CUDA version mismatch.
- Ensure you installed torch >=2.7.1 with CUDA 12.8+ support (see installation section).
- Run nvidia-smi to check your driver and CUDA version.
Q: How do I enable Ascend NPU support?
- Install torch_npu and set device_type = "npu" in [hardware] or [ascend] in config.toml.
- See the Ascend documentation for details.
Q: How do config parameters interact?
- [hardware] > [training] for overlapping parameters.
- [m3]/[reranker] > [training]/[hardware] for their models.
- See the config class docstrings and config.toml for details.
Q: How do I run tests?
- Install pytest and run pytest tests/ (if tests are present).
- Or run python scripts/test_installation.py for a full environment check.
📚 References
- BGE M3-Embedding: arXiv:2402.03216
- RocketQAv2: ACL Anthology
- ANCE: arXiv:2007.00808
- ColBERT: SIGIR 2020
- SPLADE: arXiv:2109.10086
🔮 Roadmap
- Knowledge Graph integration
- Multi-modal retrieval support
- Additional language support
- Distributed training optimization
- Model quantization for deployment
- Benchmarking on standard datasets