# BGE Fine-tuning for Chinese Book Retrieval

A comprehensive, production-ready implementation for fine-tuning the BAAI BGE-M3 embedding model and the BGE-Reranker cross-encoder for optimal retrieval performance on Chinese book content, with English query support.

---

## ⚡️ Requirements & Compatibility

- **Python >= 3.10**
- **CUDA >= 12.8** (for NVIDIA RTX 5090 and similar)
- **PyTorch >= 2.7.1** (or a nightly build with CUDA 12.8+ support)
- **See requirements.txt for all dependencies**
- **Ascend NPU support:** Install `torch_npu` and set the `[hardware]`/`[ascend]` config as needed

### Hardware Support

- **NVIDIA GPUs**: RTX 5090, A100, H100 (CUDA 12.8+)
- **Huawei Ascend NPUs**: Ascend 910, 910B (with torch_npu)
- **CPU**: Fallback support for development/testing
---

## 🚀 Features

### Core Capabilities

- **State-of-the-art Techniques**: InfoNCE loss with proper temperature scaling, self-knowledge distillation, ANCE-style hard negative mining, RocketQAv2 joint training
- **Multi-functionality**: Dense, sparse (SPLADE-style), and multi-vector (ColBERT-style) retrieval
- **Cross-lingual Support**: Chinese primary with English query compatibility
- **Platform Compatibility**: NVIDIA CUDA and Huawei Ascend NPU support with automatic detection
- **Comprehensive Evaluation**: Built-in metrics (Recall@k, MRR@k, NDCG@k, MAP) with additional success rate and coverage metrics
- **Production Ready**: Robust error handling, atomic checkpointing, comprehensive logging, and distributed training

### Recent Enhancements

- ✅ **Enhanced Error Recovery**: Training continues despite batch failures with configurable thresholds
- ✅ **Atomic Checkpoint Saving**: Prevents corruption with temporary directories and backup mechanisms
- ✅ **Improved Memory Management**: Periodic cache clearing and gradient checkpointing for large models
- ✅ **Configuration Precedence**: Clear hierarchy (CLI > Model-specific > Hardware > Training > Defaults)
- ✅ **Unified AMP Imports**: Consistent torch.cuda.amp usage across all modules
- ✅ **Comprehensive Metrics**: Fixed MRR/NDCG implementations with edge case handling
- ✅ **Ascend NPU Backend**: Complete support module with device detection and optimization
- ✅ **Enhanced Dataset Validation**: Robust format detection with helpful error messages

---
## 📋 Technical Overview

### BGE-M3 (Embedding Model)

- **InfoNCE Loss** with temperature τ=0.02 (verified from the BGE-M3 paper; a minimal sketch follows this list)
- **Proper Loss Normalization**: Fixed cross-entropy calculation in listwise loss
- **Self-Knowledge Distillation** combining dense, sparse, and multi-vector signals
- **ANCE-style Hard Negative Mining** with asynchronous ANN index updates
- **Multi-granularity Support** up to 8192 tokens
- **Unified Training** for all three functionalities
- **Gradient Checkpointing** for memory-efficient training on large batches

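For reference, the contrastive objective can be sketched in a few lines of PyTorch. This is an illustrative stand-in, not the project's `models/losses.py` implementation; the `temperature=0.02` default simply mirrors the `[m3]` config.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(q_emb: torch.Tensor, p_emb: torch.Tensor, temperature: float = 0.02) -> torch.Tensor:
    """Illustrative InfoNCE with in-batch negatives.

    q_emb: (B, D) query embeddings; p_emb: (B, D) embeddings of the matching positives.
    Every other passage in the batch acts as a negative for a given query.
    """
    q_emb = F.normalize(q_emb, dim=-1)
    p_emb = F.normalize(p_emb, dim=-1)
    logits = q_emb @ p_emb.T / temperature                      # (B, B) scaled similarities
    targets = torch.arange(q_emb.size(0), device=q_emb.device)  # diagonal entries are the positives
    return F.cross_entropy(logits, targets)
```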
### BGE-Reranker (Cross-encoder)

- **Listwise Cross-Entropy** loss (fixed implementation, not KL divergence; see the sketch below)
- **Dynamic Listwise Distillation** following the RocketQAv2 approach
- **Hard Negative Selection** from retriever results
- **Knowledge Distillation** from teacher models
- **Mixed Precision Training** with bfloat16 support for RTX 5090

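A minimal sketch of the listwise cross-entropy objective over one training group (one positive plus `train_group_size - 1` negatives). Illustrative only; the positive-at-index-0 convention is an assumption of this sketch, not necessarily the project's data layout.

```python
import torch
import torch.nn.functional as F

def listwise_ce_loss(scores: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """scores: (num_groups, group_size) cross-encoder logits, positive passage at index 0."""
    log_probs = F.log_softmax(scores / temperature, dim=-1)
    return -log_probs[:, 0].mean()  # maximize the probability assigned to the positive
```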
### Joint Training (RocketQAv2)

- **Mutual Distillation** between retriever and reranker
- **Dynamic Hard Negative Mining** using evolving models
- **Hybrid Data Augmentation** strategies
- **Error-resilient Training Loop** with failure recovery

---
## 🛠 Installation

### Hardware Requirements

- **For RTX 5090**: Python 3.10+, CUDA >= 12.8, Driver >= 545.x
- **For Ascend NPU**: Python 3.10+, CANN >= 7.0, torch_npu
- **Memory**: 24GB+ GPU memory recommended for full model training

### Step 1: Install PyTorch (CUDA 12.8+ for RTX 5090)

```bash
# Install PyTorch nightly with CUDA 12.8 support
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
```

### Step 2: Install Dependencies

```bash
pip install -r requirements.txt
```

### Step 3: (Optional) Platform-specific Setup

#### For Huawei Ascend NPU:

```bash
# Install torch_npu (follow official Huawei documentation)
pip install torch_npu

# Set environment variables
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3  # Your NPU devices
```

#### For Distributed Training:

```bash
# No additional setup needed - automatic backend selection:
# - NCCL for NVIDIA GPUs
# - HCCL for Ascend NPUs
# - Gloo for CPU
```

### Step 4: Download Models (Optional - auto-downloads if not present)

```bash
# Models will auto-download on first use, or manually download:
huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3
huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base
```

---
## 🔧 Configuration

### Configuration System Overview

The project uses a hierarchical configuration system with clear precedence rules and automatic validation. Configuration is managed through `config.toml` with programmatic overrides via command-line arguments.

### Configuration Precedence Order (Highest to Lowest)

1. **Command-line arguments** - Override any config file settings
2. **Model-specific sections** - `[m3]` for BGE-M3, `[reranker]` for BGE-Reranker
3. **Hardware-specific sections** - `[hardware]`, `[ascend]` for device-specific settings
4. **Training defaults** - `[training]` section for general defaults
5. **Dataclass field defaults** - Built-in defaults in code
Example precedence in action:

```toml
# config.toml
[training]
default_batch_size = 4

[hardware]
default_batch_size = 8              # Overrides [training] for hardware optimization

[m3]
per_device_train_batch_size = 16    # Overrides both [hardware] and [training] for the M3 model
```
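To make the precedence concrete, here is a hedged sketch of how layered settings can be merged. It shows the ordering only; the project's actual loader lives in `utils/config_loader.py` and may differ in detail.

```python
import tomllib  # Python 3.11+; use the third-party `tomli` package on 3.10

def resolve_setting(key, cli_args, config, model_section="m3", default=None):
    """Return the effective value for `key` following
    CLI > [m3]/[reranker] > [hardware] > [training] > built-in default."""
    for layer in (
        cli_args,                       # command-line overrides
        config.get(model_section, {}),  # model-specific section
        config.get("hardware", {}),     # hardware section
        config.get("training", {}),     # general training defaults
    ):
        if key in layer and layer[key] is not None:
            return layer[key]
    return default

with open("config.toml", "rb") as f:
    cfg = tomllib.load(f)
batch_size = resolve_setting("per_device_train_batch_size", cli_args={}, config=cfg, default=4)
```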
### Main Configuration File (config.toml)

```toml
# API Configuration for augmentation and external services
[api]
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here"    # Replace with your actual API key
model = "deepseek-chat"
max_retries = 3
timeout = 30
temperature = 0.7

# General training defaults
[training]
default_batch_size = 4
default_learning_rate = 1e-5
default_num_epochs = 3
default_warmup_ratio = 0.1
default_weight_decay = 0.01
gradient_accumulation_steps = 4
max_grad_norm = 1.0
logging_steps = 100
save_strategy = "epoch"
evaluation_strategy = "epoch"
load_best_model_at_end = true
metric_for_best_model = "eval_loss"
greater_is_better = false
save_total_limit = 3
fp16 = true
bf16 = false                     # Enable for A100/RTX5090
gradient_checkpointing = false
dataloader_num_workers = 4
remove_unused_columns = true
label_names = ["labels"]

# Model paths with automatic download
[model_paths]
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
cache_dir = "./cache/models"

# Data processing configuration
[data]
max_query_length = 64
max_passage_length = 512
max_seq_length = 512
validation_split = 0.1
random_seed = 42
preprocessing_num_workers = 4
overwrite_cache = false
pad_to_max_length = false
return_tensors = "pt"

# Hardware-specific optimizations
[hardware]
device_type = "cuda"             # cuda, npu, cpu
mixed_precision = true
mixed_precision_dtype = "fp16"   # fp16, bf16
enable_tf32 = true               # For Ampere GPUs
cudnn_benchmark = true
cudnn_deterministic = false
per_device_train_batch_size = 8
per_device_eval_batch_size = 16

# Distributed training configuration
[distributed]
backend = "nccl"                 # nccl, gloo, hccl (auto-detected if not set)
timeout_minutes = 30
find_unused_parameters = false
gradient_as_bucket_view = true
static_graph = false

# Huawei Ascend NPU specific settings
[ascend]
device_type = "npu"
backend = "hccl"
mixed_precision = true
mixed_precision_dtype = "fp16"
visible_devices = "0,1,2,3"
enable_graph_mode = false
enable_profiling = false
profiling_start_step = 10
profiling_stop_step = 20

# BGE-M3 specific configuration
[m3]
model_name_or_path = "BAAI/bge-m3"
learning_rate = 1e-5
per_device_train_batch_size = 16
per_device_eval_batch_size = 32
num_train_epochs = 5
warmup_ratio = 0.1
temperature = 0.02               # InfoNCE temperature (from paper)
use_self_distill = true
self_distill_weight = 1.0
use_hard_negatives = true
num_hard_negatives = 15
negatives_cross_batch = true
pooling_method = "cls"
normalize_embeddings = true
use_sparse = true
use_colbert = true
colbert_dim = 32

# BGE-Reranker specific configuration
[reranker]
model_name_or_path = "BAAI/bge-reranker-base"
learning_rate = 6e-5
per_device_train_batch_size = 8
per_device_eval_batch_size = 16
num_train_epochs = 3
warmup_ratio = 0.06
train_group_size = 16
loss_type = "listwise_ce"        # listwise_ce, ranknet, lambdarank
use_gradient_checkpointing = true
temperature = 1.0
use_distillation = false
teacher_model_path = "BAAI/bge-reranker-large"
distillation_weight = 0.5

# Data augmentation settings
[augmentation]
enable_query_augmentation = false
enable_passage_augmentation = false
augmentation_methods = ["paraphrase", "back_translation", "mask_and_fill"]
num_augmentations_per_sample = 2
augmentation_temperature = 0.8
api_augmentation = true
local_augmentation_model = "google/flan-t5-base"

# Hard negative mining configuration
[hard_negative_mining]
enable_mining = true
mining_strategy = "ance"         # ance, random, bm25
num_hard_negatives = 15
negative_depth = 200             # How deep to search for negatives
update_interval = 1000           # Steps between mining updates
margin = 0.1
use_cross_batch_negatives = true
dedup_negatives = true

# FAISS index settings for mining
[faiss]
index_type = "IVF"               # Flat, IVF, HNSW
index_params = {nlist = 100, nprobe = 10}
use_gpu = true
normalize_vectors = true

# Checkpoint and recovery settings
[checkpoint]
save_strategy = "epoch"
save_steps = 1000
save_total_limit = 3
save_on_each_node = false
resume_from_checkpoint = true
checkpoint_save_dir = "./checkpoints"
use_atomic_save = true           # Prevents corruption
keep_backup = true

# Logging configuration
[logging]
logging_dir = "./logs"
logging_strategy = "steps"
logging_steps = 100
logging_first_step = true
report_to = ["tensorboard", "wandb"]
wandb_project = "bge-finetuning"
# wandb_run_name is auto-generated when omitted (TOML has no null value)
disable_tqdm = false
tqdm_mininterval = 1

# Evaluation settings
[evaluation]
evaluation_strategy = "epoch"
eval_steps = 500
eval_batch_size = 32
metric_for_best_model = "eval_loss"
load_best_model_at_end = true
compute_metrics_fn = "accuracy_and_f1"
greater_is_better = false
# eval_accumulation_steps is unset by default

# Memory optimization
[memory]
gradient_checkpointing = true
gradient_checkpointing_kwargs = {use_reentrant = false}
optim = "adamw_torch"            # adamw_torch, adamw_hf, sgd, adafactor
# optim_args is unset by default
clear_cache_interval = 100       # Steps between cache clearing
# max_memory_mb is unset by default (no memory ceiling before clearing)

# Error handling and recovery
[error_handling]
continue_on_error = true
max_failed_batches = 10
error_log_file = "./logs/errors.log"
save_on_error = true
retry_failed_batches = true
num_retries = 3
```
### Command-Line Override Examples

```bash
# Override batch size and learning rate
python scripts/train_m3.py \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5

# Override hardware settings
python scripts/train_m3.py \
    --device npu \
    --mixed_precision_dtype bf16

# Override multiple config sections
python scripts/train_m3.py \
    --config_file custom_config.toml \
    --temperature 0.05 \
    --num_hard_negatives 20 \
    --gradient_checkpointing
```
### Environment Variable Support

```bash
# Set API key via environment variable
export BGE_API_KEY="your-api-key"

# Override config file location
export BGE_CONFIG_FILE="/path/to/custom/config.toml"

# Set distributed training environment
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export WORLD_SIZE=4
```
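A small sketch of how these variables can be picked up at startup. The variable names come from the example above; the exact lookup logic in `utils/config_loader.py` may differ.

```python
import os

# Prefer environment variables over values stored in config.toml
config_file = os.environ.get("BGE_CONFIG_FILE", "config.toml")
api_key = os.environ.get("BGE_API_KEY")
if api_key is None:
    print(f"BGE_API_KEY not set; falling back to [api].api_key from {config_file}")
```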
### Configuration Validation

The configuration system automatically validates settings:

```python
# In your code
from config.base_config import BaseConfig

config = BaseConfig.from_toml("config.toml")

# Automatic validation includes:
# - Type checking (int, float, str, bool, list, dict)
# - Value range validation (e.g., learning_rate > 0)
# - Compatibility checks (e.g., bf16 requires compatible hardware)
# - Missing required fields
# - Deprecated parameter warnings
```

### Dynamic Configuration

```python
# Programmatic configuration updates
import torch

from config.m3_config import M3Config

config = M3Config.from_toml("config.toml")

# Update based on hardware
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    config.mixed_precision_dtype = "bf16"
    config.per_device_train_batch_size *= 2

# Update based on dataset size (dataset_size computed elsewhere)
if dataset_size > 1_000_000:
    config.gradient_accumulation_steps = 8
    config.logging_steps = 500
```
### Best Practices

1. **Start with defaults**: The provided `config.toml` has sensible defaults
2. **Hardware-specific tuning**: Use the `[hardware]` section for GPU/NPU optimization
3. **Model-specific settings**: Use `[m3]` or `[reranker]` for model-specific params
4. **Environment isolation**: Use different config files for dev/staging/prod
5. **Version control**: Track config.toml changes in git
6. **Secrets management**: Use environment variables for API keys

---
## 🚀 Quick Start

### Step 1: Prepare Your Data

Create your training data in JSONL format. Each line should be a valid JSON object.

**For embedding models (BGE-M3):**

```jsonl
{"query": "What is machine learning?", "pos": ["Machine learning is..."], "neg": ["The weather is..."]}
{"query": "How to cook pasta?", "pos": ["Boil water and add pasta..."], "neg": ["Machine learning..."]}
```

**For reranker models:**

```jsonl
{"query": "What is AI?", "passage": "AI is artificial intelligence...", "label": 1}
{"query": "What is AI?", "passage": "The weather today is sunny...", "label": 0}
```
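If your raw data already lives in Python objects, a few lines are enough to emit both files in the expected shape (a minimal sketch; the field names follow the canonical format described in the Data Format section below):

```python
import json

embedding_samples = [
    {"query": "What is machine learning?",
     "pos": ["Machine learning is a field of AI..."],
     "neg": ["The weather today is sunny..."]},
]
reranker_samples = [
    {"query": "What is AI?", "passage": "AI is artificial intelligence...", "label": 1},
    {"query": "What is AI?", "passage": "The weather today is sunny...", "label": 0},
]

def write_jsonl(records, path):
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")  # one JSON object per line

write_jsonl(embedding_samples, "data/train.jsonl")
write_jsonl(reranker_samples, "data/reranker_train.jsonl")
```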
### Step 2: Fine-tune BGE-M3 (Embedding Model)

```bash
# Basic training with automatic hardware detection
python scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/bge-m3-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --learning_rate 1e-5

# Advanced training with all features enabled
python scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/bge-m3-advanced \
    --num_train_epochs 5 \
    --per_device_train_batch_size 8 \
    --learning_rate 1e-5 \
    --use_self_distill \
    --use_hard_negatives \
    --gradient_checkpointing \
    --bf16 \
    --save_total_limit 3 \
    --logging_steps 100 \
    --evaluation_strategy "steps" \
    --eval_steps 500 \
    --metric_for_best_model "eval_loss" \
    --load_best_model_at_end
```

### Step 3: Fine-tune BGE-Reranker (Cross-encoder)

```bash
# Basic reranker training
python scripts/train_reranker.py \
    --train_data data/train.jsonl \
    --output_dir ./output/bge-reranker-finetuned \
    --learning_rate 6e-5 \
    --train_group_size 16 \
    --loss_type listwise_ce

# Advanced reranker training with knowledge distillation
python scripts/train_reranker.py \
    --train_data data/train.jsonl \
    --output_dir ./output/bge-reranker-advanced \
    --learning_rate 6e-5 \
    --train_group_size 16 \
    --loss_type listwise_ce \
    --use_distillation \
    --teacher_model_path BAAI/bge-reranker-large \
    --bf16 \
    --gradient_checkpointing \
    --save_strategy "epoch" \
    --evaluation_strategy "epoch"
```

### Step 4: Joint Training (RocketQAv2-style)

```bash
# Joint training with mutual distillation
python scripts/train_joint.py \
    --train_data data/train.jsonl \
    --output_dir ./output/joint-training \
    --use_dynamic_distillation \
    --joint_loss_weight 0.3 \
    --bf16 \
    --gradient_checkpointing \
    --max_grad_norm 1.0 \
    --dataloader_num_workers 4
```

### Step 5: Multi-GPU / Distributed Training

```bash
# NVIDIA GPUs (automatic NCCL backend)
torchrun --nproc_per_node=4 scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/distributed \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 2 \
    --bf16 \
    --gradient_checkpointing

# Huawei Ascend NPUs (automatic HCCL backend)
# First, configure the config.toml [ascend] section, then:
python scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/ascend \
    --device npu
```

---
## 📊 Data Format

### Canonical Format (Required Structure)

All data formats must contain these required fields:

- `query` (string): The search query or question
- `pos` (list of strings): Positive (relevant) passages
- `neg` (list of strings): Negative (irrelevant) passages

Optional fields for advanced training:

- `pos_scores` (list of floats): Relevance scores for positives (0.0-1.0)
- `neg_scores` (list of floats): Relevance scores for negatives (0.0-1.0)
- `prompt` (string): Custom prompt for encoding
### Supported Input Formats

The dataset loader automatically detects and validates the following formats:

#### 1. JSONL (Recommended - Most Efficient)

```jsonl
{"query": "What is machine learning?", "pos": ["Machine learning is..."], "neg": ["Deep learning is..."]}
{"query": "深度学习的应用", "pos": ["深度学习在计算机视觉..."], "neg": ["机器学习是..."]}
```

#### 2. JSON (Array Format)

```json
[
  {
    "query": "What is BGE?",
    "pos": ["BGE is a family of embedding models..."],
    "neg": ["BERT is a transformer model..."],
    "pos_scores": [1.0],
    "neg_scores": [0.2]
  }
]
```

#### 3. CSV/TSV

```csv
query,pos,neg,pos_scores,neg_scores
"What is machine learning?","['ML is...', 'Machine learning refers to...']","['DL is...']","[1.0, 0.9]","[0.1]"
```

#### 4. DuReader Format

```json
{
  "data": [{
    "question": "什么是机器学习?",
    "answers": ["机器学习是..."],
    "documents": [{
      "content": "深度学习是...",
      "is_relevant": false
    }]
  }]
}
```

#### 5. MS MARCO Format (TSV)

```tsv
query_id	query	doc_id	document	relevance
1001	What is ML?	2001	Machine learning is...	1
1001	What is ML?	2002	Deep learning is...	0
```

#### 6. XML Format (TREC-style)

```xml
<topics>
  <topic>
    <query>Information retrieval basics</query>
    <positive>IR is the process of obtaining...</positive>
    <negative>Database management is...</negative>
  </topic>
</topics>
```
### Data Validation & Error Handling

The dataset loader performs comprehensive validation:

- **Format Detection**: Automatic format inference with helpful error messages
- **Field Validation**: Ensures all required fields are present and properly typed
- **Content Validation**: Checks for empty queries/passages and invalid scores
- **Encoding Validation**: Handles various text encodings (UTF-8, GB2312, etc.)
- **Size Validation**: Warns about extremely long sequences

Example validation output:

```
Dataset validation completed:
✓ Format: JSONL
✓ Total samples: 10,000
✓ Valid samples: 9,998
✗ Invalid samples: 2
  - Line 1523: Empty query field
  - Line 7891: Missing 'neg' field
✓ Encoding: UTF-8
⚠ Warnings:
  - 15 samples have passages > 8192 tokens (will be truncated)
```
### Converting Data Formats

Use the built-in preprocessor to convert between formats:

```python
from data.preprocessing import DataPreprocessor

preprocessor = DataPreprocessor()

# Convert any format to canonical JSONL
preprocessor.convert_to_canonical(
    input_file="data/msmarco.tsv",
    output_file="data/train.jsonl",
    input_format="msmarco"  # or "auto" for automatic detection
)

# Batch conversion with validation
preprocessor.batch_convert(
    input_files=["data/*.csv", "data/*.json"],
    output_dir="data/processed/",
    validate=True,
    max_workers=4
)
```
### Working with Book/Document Collections

For training on books or large documents:

1. **Split documents into passages**:

   ```python
   from data.preprocessing import DocumentSplitter

   splitter = DocumentSplitter(
       chunk_size=512,
       overlap=50,
       split_by="sentence"  # or "paragraph", "sliding_window"
   )

   passages = splitter.split_document("path/to/book.txt")
   ```

2. **Generate training pairs**:

   ```python
   from data.preprocessing import TrainingPairGenerator

   generator = TrainingPairGenerator()

   # Generate query-passage pairs with hard negatives
   training_data = generator.generate_from_passages(
       passages=passages,
       queries_per_passage=2,
       hard_negative_ratio=0.7,
       use_bm25=True
   )

   # Save in canonical format
   generator.save_jsonl(training_data, "data/book_train.jsonl")
   ```

3. **Use data optimization**:

   ```bash
   # Optimize data for BGE-M3 training
   python -m data.optimization bge-m3 data/book_train.jsonl \
       --output_dir ./optimized_data \
       --hard_negative_ratio 0.8 \
       --augment_queries
   ```
### Memory-Efficient Data Loading

For large datasets, use streaming mode:

```python
# In your training script
dataset = ContrastiveDataset(
    data_file="data/large_dataset.jsonl",
    tokenizer=tokenizer,
    streaming=True,     # Enable streaming
    buffer_size=10000   # Buffer size for shuffling
)
```
### 📖 Training with Book `.txt` Files

- **You cannot train directly on a raw `.txt` book file.**
- **Convert your book to a supported format** (a minimal end-to-end sketch follows this list):
  - Split the book into passages (e.g., by chapter or paragraph).
  - Create queries and annotate which passages are positive (relevant) and negative (irrelevant) for each query.
  - Example:

    ```jsonl
    {"query": "What happens in Chapter 1?", "pos": ["Chapter 1: ..."], "neg": ["Chapter 2: ..."]}
    ```

- Use the provided `DataPreprocessor` to convert CSV/TSV/JSON/XML to JSONL.

**Note:** All formats must be convertible to the canonical JSONL structure above. Use the provided `DataPreprocessor` to convert between formats.

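The sketch below turns a raw book into canonical JSONL by splitting on blank lines and pairing each passage with a random negative from elsewhere in the book. The splitting rule, placeholder query, and file names are illustrative assumptions; in practice you would write real queries (or use the augmentation module) instead of the placeholder.

```python
import json
import random

# Split the book into paragraph-level passages (assumption: paragraphs separated by blank lines)
with open("book.txt", encoding="utf-8") as f:
    passages = [p.strip() for p in f.read().split("\n\n") if len(p.strip()) > 50]

random.seed(42)
with open("data/book_train.jsonl", "w", encoding="utf-8") as out:
    for i, passage in enumerate(passages):
        negative = random.choice([p for j, p in enumerate(passages) if j != i])
        record = {
            "query": f"Placeholder query about passage {i}",  # replace with a real or generated query
            "pos": [passage],
            "neg": [negative],
        }
        out.write(json.dumps(record, ensure_ascii=False) + "\n")
```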
---

## ⚙️ Configuration Reference

### Parameter Precedence

- `[hardware]` overrides `[training]` for overlapping parameters (e.g., batch size).
- `[m3]` and `[reranker]` override `[training]`/`[hardware]` for their respective models.
- For Ascend NPU, use the `[ascend]` section and set `device_type = "npu"`.

### Hardware Compatibility

- **NVIDIA RTX 5090:** Use torch >= 2.7.1 and CUDA 12.8+ (see above).
- **Huawei Ascend NPU:** Install `torch_npu`, set the `[hardware]` or `[ascend]` config, and ensure your environment matches Ascend requirements.
### API Configuration (config.toml)

```toml
[api]
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here"
model = "deepseek-chat"

[training]
default_batch_size = 4
default_learning_rate = 1e-5

[model_paths]
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
```
### Configuration Parameter Reference

#### [api]

| Parameter   | Type  | Default                   | Description                   |
|-------------|-------|---------------------------|-------------------------------|
| base_url    | str   | "https://api.deepseek..." | API endpoint for augmentation |
| api_key     | str   | "your-api-key-here"       | API key for external services |
| model       | str   | "deepseek-chat"           | Model name for API            |
| max_retries | int   | 3                         | Max API retries               |
| timeout     | int   | 30                        | API timeout (seconds)         |
| temperature | float | 0.7                       | API sampling temperature      |

#### [training]

| Parameter                    | Type  | Default | Description                     |
|------------------------------|-------|---------|---------------------------------|
| default_batch_size           | int   | 4       | Default batch size for training |
| default_learning_rate        | float | 1e-5    | Default learning rate           |
| default_num_epochs           | int   | 3       | Default number of epochs        |
| default_warmup_ratio         | float | 0.1     | Warmup steps ratio              |
| default_weight_decay         | float | 0.01    | Weight decay for optimizer      |
| gradient_accumulation_steps  | int   | 4       | Steps to accumulate gradients   |

#### [model_paths]

| Parameter    | Type | Default                      | Description                |
|--------------|------|------------------------------|----------------------------|
| bge_m3       | str  | "./models/bge-m3"            | Path to BGE-M3 model       |
| bge_reranker | str  | "./models/bge-reranker-base" | Path to BGE-Reranker model |
| cache_dir    | str  | "./cache/models"             | Model cache directory      |

#### [data]

| Parameter          | Type  | Default | Description                     |
|--------------------|-------|---------|---------------------------------|
| max_query_length   | int   | 64      | Max query length                |
| max_passage_length | int   | 512     | Max passage length              |
| max_seq_length     | int   | 512     | Max sequence length             |
| validation_split   | float | 0.1     | Validation split ratio          |
| random_seed        | int   | 42      | Random seed for reproducibility |

#### [augmentation]

| Parameter                    | Type  | Default                            | Description                                    |
|------------------------------|-------|------------------------------------|------------------------------------------------|
| enable_query_augmentation    | bool  | false                              | Enable query augmentation                      |
| enable_passage_augmentation  | bool  | false                              | Enable passage augmentation                    |
| augmentation_methods         | list  | ["paraphrase", "back_translation"] | Augmentation methods to use                    |
| num_augmentations_per_sample | int   | 2                                  | Number of augmentations per sample             |
| augmentation_temperature     | float | 0.8                                | Temperature for augmentation                   |
| api_first                    | bool  | true                               | Prefer LLM API for all augmentation if enabled |

#### [hard_negative_mining]

| Parameter              | Type  | Default | Description                      |
|------------------------|-------|---------|----------------------------------|
| enable_mining          | bool  | true    | Enable hard negative mining      |
| num_hard_negatives     | int   | 15      | Number of hard negatives to mine |
| negative_range_min     | int   | 10      | Min negative index for mining    |
| negative_range_max     | int   | 100     | Max negative index for mining    |
| margin                 | float | 0.1     | Margin for ANCE mining           |
| index_refresh_interval | int   | 1000    | Steps between index refreshes    |
| index_type             | str   | "IVF"   | FAISS index type                 |
| nlist                  | int   | 100     | Number of clusters for IVF       |

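To illustrate what the mining knobs control, here is a compact, hedged sketch of ANCE-style hard negative mining with FAISS. Parameter names mirror the table above; the project's actual implementation lives in `data/hard_negative_mining.py` and may differ.

```python
import faiss
import numpy as np

def mine_hard_negatives(query_vecs: np.ndarray, passage_vecs: np.ndarray, positives,
                        num_hard_negatives=15, negative_range_min=10, negative_range_max=100):
    """query_vecs: (Q, D) and passage_vecs: (P, D) L2-normalized float32 arrays.
    positives: list of sets of passage ids known to be relevant for each query."""
    index = faiss.IndexFlatIP(passage_vecs.shape[1])  # inner product == cosine on normalized vectors
    index.add(passage_vecs)
    _, ranked = index.search(query_vecs, negative_range_max)
    hard_negatives = []
    for qid, ids in enumerate(ranked):
        # Skip the very top ranks (likely relevant) and drop known positives
        candidates = [int(i) for i in ids[negative_range_min:] if int(i) not in positives[qid]]
        hard_negatives.append(candidates[:num_hard_negatives])
    return hard_negatives
```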
#### [m3] and [reranker]

- All model-specific parameters are documented in the config classes and code docstrings. See `config/m3_config.py` and `config/reranker_config.py` for full details.

---
## ⚡️ Distributed Training

### Overview

The project supports distributed training across multiple GPUs/NPUs with automatic backend detection, error recovery, and optimized communication strategies. The distributed training system is fully configuration-driven and handles both data-parallel and model-parallel strategies.

### Automatic Backend Detection

The system automatically selects the optimal backend based on your hardware:

- **NVIDIA GPUs**: NCCL (fastest for GPU-to-GPU communication)
- **Huawei Ascend NPUs**: HCCL (optimized for the Ascend architecture)
- **CPU or Mixed**: Gloo (fallback for heterogeneous setups)

### Basic Usage

#### Single-Node Multi-GPU (NVIDIA)

```bash
# Automatic detection and setup
torchrun --nproc_per_node=4 scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/distributed

# With specific GPU selection
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/distributed \
    --gradient_accumulation_steps 2
```

#### Multi-Node Training (NVIDIA)

```bash
# On node 0 (master)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr=192.168.1.1 --master_port=29500 \
    scripts/train_m3.py --train_data data/train.jsonl

# On node 1
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
    --master_addr=192.168.1.1 --master_port=29500 \
    scripts/train_m3.py --train_data data/train.jsonl
```

#### Huawei Ascend NPU Training

```bash
# Single-node multi-NPU
python scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/ascend \
    --device npu

# With specific NPU selection
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
python scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/ascend
```
### Configuration Options

```toml
# config.toml
[distributed]
backend = "auto"                 # auto, nccl, gloo, hccl
timeout_minutes = 30
find_unused_parameters = false
gradient_as_bucket_view = true
broadcast_buffers = true
bucket_cap_mb = 25
use_distributed_optimizer = false

# Performance optimizations
[distributed.optimization]
overlap_comm = true
gradient_compression = false
compression_type = "none"        # none, powerSGD, fp16
all_reduce_fusion = true
fusion_threshold_mb = 64

# Fault tolerance
[distributed.fault_tolerance]
enable_checkpointing = true
checkpoint_interval = 1000
restart_on_failure = true
max_restart_attempts = 3
heartbeat_interval = 60
```
### Advanced Features

#### 1. Gradient Accumulation with Distributed Training

```bash
# Effective batch size = per_device_batch * gradient_accumulation * num_gpus
# Example: 8 * 4 * 4 = 128 effective batch size
python scripts/train_m3.py \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --distributed
```

#### 2. Mixed Precision with Distributed

```bash
# Automatic mixed precision for faster training
torchrun --nproc_per_node=4 scripts/train_m3.py \
    --bf16 \
    --tf32   # Additional optimization for Ampere GPUs
```

#### 3. DeepSpeed Integration

```bash
# With DeepSpeed for extreme scale
deepspeed scripts/train_m3.py \
    --deepspeed deepspeed_config.json \
    --train_data data/train.jsonl
```

Example DeepSpeed config:

```json
{
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto",
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000
  },
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  }
}
```
#### 4. Fault Tolerance and Recovery

The distributed training system includes robust error handling. Training automatically recovers from:

- Temporary network failures
- GPU memory errors (with retry)
- Node failures (with checkpoint recovery)
- Communication timeouts

#### 5. Performance Monitoring

```bash
# Enable distributed profiling
python scripts/train_m3.py \
    --enable_profiling \
    --profiling_start_step 100 \
    --profiling_end_step 200

# Monitor with tensorboard
tensorboard --logdir ./logs --bind_all
```
### Best Practices

1. **Data Loading**

   ```bash
   # The distributed sampler is used automatically; set num_workers based on available CPUs
   --dataloader_num_workers 4
   ```

2. **Batch Size Scaling**

   ```bash
   # Scale the learning rate with the total batch size:
   #   LR = base_lr * (total_batch_size / base_batch_size)
   # e.g. base_lr = 1e-5 with 4 GPUs x batch 16 vs. a base batch of 16:
   --learning_rate 4e-5
   ```

3. **Communication Optimization**

   ```bash
   # Reduce communication overhead
   --gradient_accumulation_steps 4
   --broadcast_buffers false   # If not needed
   ```

4. **Memory Optimization**

   ```bash
   # For large models
   --gradient_checkpointing
   --zero_stage 2   # With DeepSpeed
   ```
### Troubleshooting

#### NCCL Errors

```bash
# Debug NCCL issues
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

# Common fixes
export NCCL_IB_DISABLE=1    # Disable InfiniBand
export NCCL_P2P_DISABLE=1   # Disable P2P
```

#### Hanging Training

```bash
# Set timeout for debugging
export NCCL_TIMEOUT=300   # 5 minutes
export NCCL_ASYNC_ERROR_HANDLING=1
```

#### Memory Issues

```bash
# Clear cache periodically
--clear_cache_interval 100

# Use gradient checkpointing
--gradient_checkpointing
```
### Performance Benchmarks

Typical speedup with distributed training:

- 2 GPUs: 1.8-1.9x speedup
- 4 GPUs: 3.5-3.8x speedup
- 8 GPUs: 6.5-7.5x speedup

Factors affecting performance:

- Model size (larger models scale better)
- Batch size (larger batches reduce communication overhead)
- Network bandwidth (especially for multi-node)
- Data loading speed

---
## 📊 Evaluation Metrics

### Overview

The evaluation system provides comprehensive metrics for both retrieval and reranking tasks, with support for graded relevance, statistical significance testing, and performance benchmarking.

### Core Metrics

#### Retrieval Metrics

1. **Recall@K** - Fraction of relevant documents retrieved in the top K

   ```python
   recall = |retrieved_relevant| / |total_relevant|
   ```

2. **Precision@K** - Fraction of retrieved documents that are relevant

   ```python
   precision = |retrieved_relevant| / K
   ```

3. **Mean Average Precision (MAP)** - Average precision at each relevant position

   ```python
   MAP = mean([AP(q) for q in queries])
   where AP = sum([P@k * rel(k) for k in positions]) / |relevant|
   ```

4. **Mean Reciprocal Rank (MRR@K)** - Average reciprocal of the rank of the first relevant result

   ```python
   MRR = mean([1/rank_of_first_relevant for q in queries])
   ```

5. **Normalized Discounted Cumulative Gain (NDCG@K)** - Position-weighted relevance score

   ```python
   DCG@K  = sum([rel(i) / log2(i + 1) for i in 1..K])
   NDCG@K = DCG@K / IDCG@K
   ```

6. **Success Rate@K** - Fraction of queries with at least one relevant result in the top K

   ```python
   success_rate = |queries_with_relevant_in_topK| / |total_queries|
   ```

7. **Coverage@K** - Fraction of all relevant documents found across all queries

   ```python
   coverage = |unique_relevant_retrieved| / |total_unique_relevant|
   ```

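As a concrete reference for how MRR@K and NDCG@K fall out of these definitions, here is a small self-contained sketch (binary or graded relevance labels in rank order). The project's implementations live in `evaluation/metrics.py`; this is for illustration only.

```python
import math

def mrr_at_k(relevances, k=10):
    """relevances: 0/1 flags of the ranked results, best first."""
    for rank, rel in enumerate(relevances[:k], start=1):
        if rel > 0:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(relevances, k=10):
    """relevances: graded gains (e.g. 0, 1, 2) of the ranked results, best first."""
    dcg = sum(rel / math.log2(i + 1) for i, rel in enumerate(relevances[:k], start=1))
    ideal = sorted(relevances, reverse=True)
    idcg = sum(rel / math.log2(i + 1) for i, rel in enumerate(ideal[:k], start=1))
    return dcg / idcg if idcg > 0 else 0.0

print(mrr_at_k([0, 1, 0]))        # 0.5  (first relevant result at rank 2)
print(ndcg_at_k([1, 2, 0], k=3))  # ~0.86
```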
#### Reranking Metrics

1. **Accuracy@1** - Top-1 ranking accuracy
2. **MRR** - Mean reciprocal rank for reranked results
3. **NDCG@K** - Ranking quality with graded relevance
4. **ERR (Expected Reciprocal Rank)** - Cascade model metric

### Usage Examples

#### Basic Evaluation - **NEW STREAMLINED SYSTEM** 🎯

```bash
# Quick validation (5 minutes)
python scripts/validate.py quick \
    --retriever_model ./output/bge-m3-finetuned \
    --reranker_model ./output/bge-reranker-finetuned

# Comprehensive evaluation (30 minutes)
python scripts/validate.py comprehensive \
    --retriever_model ./output/bge-m3-finetuned \
    --reranker_model ./output/bge-reranker-finetuned \
    --test_data_dir ./data/test

# Compare with baselines
python scripts/validate.py compare \
    --retriever_model ./output/bge-m3-finetuned \
    --reranker_model ./output/bge-reranker-finetuned \
    --retriever_data data/test_retriever.jsonl \
    --reranker_data data/test_reranker.jsonl

# Complete validation suite
python scripts/validate.py all \
    --retriever_model ./output/bge-m3-finetuned \
    --reranker_model ./output/bge-reranker-finetuned \
    --test_data_dir ./data/test
```

#### Advanced Evaluation with Graded Relevance

```python
# Data format with relevance scores
{
    "query": "machine learning basics",
    "documents": [
        {"text": "Introduction to ML", "relevance": 2},   # Highly relevant
        {"text": "Deep learning guide", "relevance": 1},  # Somewhat relevant
        {"text": "Database systems", "relevance": 0}      # Not relevant
    ]
}
```

```bash
# Comprehensive evaluation with detailed metrics (includes NDCG, graded relevance)
python scripts/validate.py comprehensive \
    --retriever_model ./output/bge-m3-finetuned \
    --reranker_model ./output/bge-reranker-finetuned \
    --test_data_dir data/test \
    --batch_size 32
```

#### Statistical Significance Testing

```bash
# Compare models with statistical analysis
python scripts/compare_models.py \
    --model_type retriever \
    --baseline_model BAAI/bge-m3 \
    --finetuned_model ./output/bge-m3-finetuned \
    --data_path data/test.jsonl \
    --num_bootstrap_samples 1000 \
    --confidence_level 0.95

# Output includes:
# - Metric comparisons with confidence intervals
# - Statistical significance (p-values)
# - Effect sizes (Cohen's d)
```
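Conceptually, the bootstrap comparison can be sketched as below, given per-query metric arrays for the two models. This is a hedged illustration, not the code in `scripts/compare_models.py`.

```python
import numpy as np

def bootstrap_delta(metric_baseline, metric_finetuned, num_samples=1000, confidence=0.95, seed=42):
    """metric_*: per-query scores (e.g. NDCG@10) for the two models, in the same query order."""
    rng = np.random.default_rng(seed)
    diffs = np.asarray(metric_finetuned) - np.asarray(metric_baseline)
    deltas = np.array([rng.choice(diffs, size=len(diffs), replace=True).mean()
                       for _ in range(num_samples)])
    low, high = np.percentile(deltas, [(1 - confidence) / 2 * 100, (1 + confidence) / 2 * 100])
    p_value = 2 * min((deltas <= 0).mean(), (deltas >= 0).mean())  # crude two-sided bootstrap p-value
    return diffs.mean(), (low, high), p_value
```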
### Metric Computation Details

#### Handling Edge Cases

1. **Empty Results**
   - Recall: 0 when nothing relevant is retrieved (reported as 0 when no relevant documents exist)
   - Precision: 0 for empty results
   - MRR: 0 if no relevant document is found
   - NDCG: 0 for empty results

2. **Tied Scores**
   - Uses stable sorting to maintain consistency
   - Averages metrics across tied positions

3. **Missing Relevance Labels**
   - Binary relevance: treats unlabeled documents as non-relevant
   - Graded relevance: requires explicit labels

#### Efficiency Optimizations

```python
# Batch evaluation for large datasets
evaluator = Evaluator(
    model=model,
    batch_size=128,
    use_fp16=True,
    device="cuda"
)

# Streaming evaluation for memory efficiency
results = evaluator.evaluate_streaming(
    data_path="large_dataset.jsonl",
    chunk_size=10000
)
```
### Performance Benchmarking

```bash
# Benchmark inference speed
python scripts/benchmark.py \
    --model_type retriever \
    --model_path ./output/bge-m3-finetuned \
    --data_path data/test.jsonl \
    --batch_sizes 1 8 16 32 64 \
    --num_warmup_steps 10 \
    --num_benchmark_steps 100

# Output includes:
# - Throughput (queries/second)
# - Latency percentiles (p50, p90, p95, p99)
# - Memory usage
# - Hardware utilization
```

### Custom Metrics

```python
from evaluation.metrics import BaseMetric

class CustomMetric(BaseMetric):
    def compute(self, predictions, references, k=10):
        # Your custom metric logic
        return score

# Register the custom metric
from evaluation.evaluator import Evaluator

evaluator = Evaluator(model=model)
evaluator.add_metric("custom_metric", CustomMetric())
```

### Visualization and Reporting

```python
# Generate evaluation report
from evaluation.reporting import EvaluationReport

report = EvaluationReport(results)
report.generate_html("evaluation_report.html")
report.generate_latex("evaluation_report.tex")
report.plot_metrics("plots/")

# Example plots:
# - Recall@K curves
# - Precision-Recall curves
# - NDCG@K comparison
# - Query-level performance distribution
```

### Best Practices

1. **Use Appropriate K Values**
   - Small K (1, 5, 10) for user-facing metrics
   - Large K (50, 100) for reranking scenarios

2. **Consider Query Types**
   - Navigational queries: Focus on MRR, Success@1
   - Informational queries: Focus on NDCG, Recall@K

3. **Statistical Rigor**
   - Always use test sets with 1000+ queries
   - Report confidence intervals
   - Use paired tests for model comparison

4. **Graded Relevance**
   - Use NDCG for nuanced evaluation
   - Define clear relevance guidelines
   - Consider query-specific relevance scales

---
## 📈 Model Evaluation & Benchmarking - **NEW STREAMLINED SYSTEM** 🎯

> **⚡ SIMPLIFIED**: The old, confusing set of separate validation scripts has been replaced with **one simple command**!

### **Quick Start - Single Command for Everything:**

```bash
# Quick validation (5 minutes) - Did my training work?
python scripts/validate.py quick \
    --retriever_model ./output/bge_m3/final_model \
    --reranker_model ./output/reranker/final_model

# Compare with baselines (15 minutes) - How much did I improve?
python scripts/validate.py compare \
    --retriever_model ./output/bge_m3/final_model \
    --reranker_model ./output/reranker/final_model \
    --retriever_data ./test_data/m3_test.jsonl \
    --reranker_data ./test_data/reranker_test.jsonl

# Complete validation suite (1 hour) - Production ready?
python scripts/validate.py all \
    --retriever_model ./output/bge_m3/final_model \
    --reranker_model ./output/reranker/final_model \
    --test_data_dir ./test_data
```

### **Validation Modes:**

| Mode | Time | Purpose | Command |
|------|------|---------|---------|
| `quick` | 5 min | Sanity check | `python scripts/validate.py quick --retriever_model ... --reranker_model ...` |
| `compare` | 15 min | Baseline comparison | `python scripts/validate.py compare --retriever_model ... --reranker_data ...` |
| `comprehensive` | 30 min | Detailed metrics | `python scripts/validate.py comprehensive --test_data_dir ...` |
| `benchmark` | 10 min | Performance only | `python scripts/validate.py benchmark --retriever_model ...` |
| `all` | 1 hour | Complete suite | `python scripts/validate.py all --test_data_dir ...` |

### **Advanced Usage:**

```bash
# Unified model comparison
python scripts/compare_models.py \
    --model_type both \
    --finetuned_retriever ./output/bge_m3/final_model \
    --finetuned_reranker ./output/reranker/final_model \
    --retriever_data ./test_data/m3_test.jsonl \
    --reranker_data ./test_data/reranker_test.jsonl

# Performance benchmarking
python scripts/benchmark.py \
    --model_type retriever \
    --model_path ./output/bge_m3/final_model \
    --data_path ./test_data/m3_test.jsonl
```

**📋 See [README_VALIDATION.md](./README_VALIDATION.md) for complete documentation**

---
## 🖥️ Ascend NPU Section

- **Install**: `pip install torch_npu` (see the [Ascend docs](https://gitee.com/ascend/pytorch))
- **Edit `config.toml`:**

  ```toml
  [ascend]
  device_type = "npu"
  backend = "hccl"
  mixed_precision = true
  visible_devices = "0,1,2,3"
  ```

- **Launch training:**

  ```bash
  python3.10 scripts/train_m3.py --train_data data/train.jsonl --output_dir ./output/npu
  ```

- **Distributed:** Use Ascend's distributed launcher or set `backend = "hccl"` in `[ascend]`.
- **All device/backend selection is config-driven.**

---
## 🚀 Performance Optimization

### Multi-GPU Training

```bash
torchrun --nproc_per_node=2 scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/distributed \
    --fp16 \
    --gradient_checkpointing
```

### Memory Optimization

- **Gradient Checkpointing**: Reduces memory usage
- **Mixed Precision**: FP16 training for RTX 5090
- **Gradient Accumulation**: Effective larger batch sizes
- **DeepSpeed Integration**: For large-scale training

### Huawei Ascend Compatibility

```python
# For Ascend deployment
import torch
import torch_npu  # registers the "npu" device type

device = torch.device("npu")
model.to(device)
```
---

## 📝 **Recent Changes & Improvements**

### ✅ **Simplified Data Pipeline (Latest)**

- **Removed unnecessary PKL conversion pipeline**: The `data/optimization.py` module has been removed
  - Performance benefits came from pre-tokenization, not the PKL format
  - PKL files caused memory issues and added complexity without real benefits
- **Added in-memory caching option**: For small datasets (< 100k samples), enable `cache_in_memory=True`
- **Kept JSONL as primary format**: Industry standard, human-readable, and efficient
- **All core functionality intact**: Training, evaluation, and model features unchanged

### 🚀 **Performance Tips**

For better training performance:

```python
# Enable in-memory caching for small datasets
dataset = BGEM3Dataset(
    data_path="train.jsonl",
    tokenizer=tokenizer,
    cache_in_memory=True,   # Pre-tokenize and cache in memory
    max_cache_size=100000   # Maximum samples to cache
)
```

For more optimization options, run: `python -m data.optimization --help`
---

## 🏗 Architecture

```
bge_finetune/
├── config.toml                  # Main configuration file
├── requirements.txt             # Python dependencies
├── README.md                    # Project documentation
├── __init__.py                  # Package initialization
│
├── config/                      # Configuration management
│   ├── __init__.py
│   ├── base_config.py           # Base configuration class with validation
│   ├── m3_config.py             # BGE-M3 specific configuration
│   └── reranker_config.py       # BGE-Reranker specific configuration
│
├── data/                        # Data processing and management
│   ├── __init__.py
│   ├── dataset.py               # Dataset classes (Contrastive, Reranker, Joint)
│   ├── preprocessing.py         # Format conversion and validation
│   ├── augmentation.py          # Query/passage augmentation strategies
│   ├── hard_negative_mining.py  # ANCE-style hard negative mining
│   └── optimization.py          # Data optimization utilities
│
├── models/                      # Model implementations
│   ├── __init__.py
│   ├── bge_m3.py                # BGE-M3 wrapper with multi-functionality
│   ├── bge_reranker.py          # BGE-Reranker cross-encoder wrapper
│   └── losses.py                # Loss functions (InfoNCE, ListwiseCE, etc.)
│
├── training/                    # Training logic
│   ├── __init__.py
│   ├── trainer.py               # Base trainer with error recovery
│   ├── m3_trainer.py            # BGE-M3 specialized trainer
│   ├── reranker_trainer.py      # Reranker specialized trainer
│   └── joint_trainer.py         # RocketQAv2-style joint training
│
├── evaluation/                  # Evaluation and metrics
│   ├── __init__.py
│   ├── evaluator.py             # Evaluation pipeline
│   ├── metrics.py               # Metric implementations (MRR, NDCG, etc.)
│   └── reporting.py             # Result visualization and reporting
│
├── utils/                       # Utilities and helpers
│   ├── __init__.py
│   ├── logging.py               # Distributed-aware logging with tqdm
│   ├── checkpoint.py            # Atomic checkpoint save/load
│   ├── distributed.py           # Distributed training utilities
│   ├── config_loader.py         # TOML config loading with precedence
│   ├── memory_utils.py          # Memory optimization utilities
│   └── ascend_support.py        # Huawei Ascend NPU support
│
├── scripts/                     # Entry point scripts
│   ├── __init__.py
│   ├── train_m3.py              # Train BGE-M3 embeddings
│   ├── train_reranker.py        # Train BGE-Reranker
│   ├── train_joint.py           # Joint training (RocketQAv2)
│   ├── validate.py              # NEW: Single validation entry point
│   ├── compare_models.py        # NEW: Unified model comparison
│   ├── benchmark.py             # Performance benchmarking
│   ├── comprehensive_validation.py  # Detailed validation suite
│   ├── quick_validation.py      # Quick validation checks
│   ├── test_installation.py     # Environment verification
│   └── setup.py                 # Initial setup script
│
├── cache/                       # Cache directory (auto-created)
│   ├── models/                  # Downloaded model cache
│   └── data/                    # Preprocessed data cache
│
├── logs/                        # Logging directory (auto-created)
│   ├── training/                # Training logs
│   ├── evaluation/              # Evaluation results
│   └── errors.log               # Error logs
│
├── output/                      # Training outputs (auto-created)
│   ├── bge-m3-finetuned/        # Fine-tuned models
│   ├── bge-reranker-finetuned/  # Fine-tuned rerankers
│   └── checkpoints/             # Training checkpoints
│
└── docs/                        # Additional documentation
    ├── API.md                   # API reference
    ├── CONTRIBUTING.md          # Contribution guidelines
    └── USAGE_GUIDE.md           # Detailed usage guide
```
### Module Descriptions

#### Core Modules

- **`config/`**: Hierarchical configuration system with TOML support and validation
  - Automatic type checking and value validation
  - Clear precedence rules (CLI > Model > Hardware > Training > Defaults)
  - Dynamic configuration based on hardware detection

- **`data/`**: Comprehensive data handling
  - Format detection and conversion (JSONL, JSON, CSV, TSV, XML, DuReader, MS MARCO)
  - Streaming support for large datasets
  - Built-in augmentation with LLM and local model support
  - ANCE-style hard negative mining with FAISS integration

- **`models/`**: Model wrappers and loss functions
  - BGE-M3: Dense, sparse (SPLADE), and multi-vector (ColBERT) support
  - BGE-Reranker: Cross-encoder with knowledge distillation
  - Loss functions: InfoNCE (τ=0.02), ListwiseCE, RankNet, LambdaRank

- **`training/`**: Robust training infrastructure
  - Error recovery with configurable failure thresholds
  - Automatic mixed precision (fp16/bf16) based on hardware
  - Gradient checkpointing and memory optimization
  - Distributed training with automatic backend selection

- **`evaluation/`**: Comprehensive evaluation suite
  - Metrics: Recall@K, MRR@K, NDCG@K, MAP, Success Rate, Coverage
  - Graded relevance support
  - Statistical significance testing
  - Performance benchmarking

- **`utils/`**: Production-ready utilities
  - TqdmLoggingHandler for progress tracking
  - Atomic checkpoint saving with backup
  - Automatic distributed backend detection (NCCL/HCCL/Gloo)
  - Full Ascend NPU support module

#### Entry Points

- **Training Scripts**:
  - `train_m3.py`: BGE-M3 training with all features
  - `train_reranker.py`: Reranker training with distillation
  - `train_joint.py`: RocketQAv2-style joint training

- **Evaluation Scripts**:
  - `validate.py`: Full evaluation with all metrics
  - `benchmark.py`: Performance profiling
  - `compare_models.py`: Side-by-side model comparison

- **Utility Scripts**:
  - `test_installation.py`: Verify environment setup
  - `setup.py`: Download models and prepare environment
### Key Features by Module

1. **Automatic Hardware Detection**
   - CUDA capability checking for optimal dtype (fp16/bf16)
   - Ascend NPU detection and configuration
   - CPU fallback for development

2. **Robust Error Handling**
   - Training continues despite batch failures
   - Automatic checkpoint recovery
   - Comprehensive error logging

3. **Memory Efficiency**
   - Gradient checkpointing
   - Periodic cache clearing
   - Streaming data loading
   - Mixed precision training

4. **Production Features**
   - Atomic checkpoint saves
   - Distributed training support
   - Comprehensive logging
   - Configuration validation

---
## 🔧 Troubleshooting & FAQ

**Q: My GPU is not detected or torch reports a CUDA version mismatch.**

- Ensure you installed torch >= 2.7.1 with CUDA 12.8+ support (see the installation section).
- Run `nvidia-smi` to check your driver and CUDA version.

**Q: How do I enable Ascend NPU support?**

- Install `torch_npu` and set `device_type = "npu"` in `[hardware]` or `[ascend]` in `config.toml`.
- See the [Ascend documentation](https://gitee.com/ascend/pytorch) for details.

**Q: How do config parameters interact?**

- `[hardware]` > `[training]` for overlapping parameters.
- `[m3]`/`[reranker]` > `[training]`/`[hardware]` for their models.
- See the config class docstrings and config.toml for details.

**Q: How do I run tests?**

- Install `pytest` and run `pytest tests/` (if tests are present).
- Or run `python scripts/test_installation.py` for a full environment check.

---
## 📚 References

1. **BGE M3-Embedding**: [arXiv:2402.03216](https://arxiv.org/abs/2402.03216)
2. **RocketQAv2**: [ACL Anthology](https://aclanthology.org/2021.emnlp-main.224/)
3. **ANCE**: [arXiv:2007.00808](https://arxiv.org/abs/2007.00808)
4. **ColBERT**: [SIGIR 2020](https://dl.acm.org/doi/10.1145/3397271.3401075)
5. **SPLADE**: [arXiv:2109.10086](https://arxiv.org/abs/2109.10086)

---

## 🔮 Roadmap

- [ ] Knowledge Graph integration
- [ ] Multi-modal retrieval support
- [ ] Additional language support
- [ ] Distributed training optimization
- [ ] Model quantization for deployment
- [ ] Benchmarking on standard datasets
|