2025-07-22 16:55:25 +08:00


# BGE Fine-tuning for Chinese Book Retrieval
A comprehensive, production-ready implementation for fine-tuning the BAAI BGE-M3 embedding model and the BGE-Reranker cross-encoder for retrieval over Chinese book content, with support for English queries.
---
## ⚡️ Requirements & Compatibility
- **Python >= 3.10**
- **CUDA >= 12.8** (for NVIDIA RTX 5090 and similar)
- **PyTorch >= 2.2.0**; RTX 5090 and other CUDA 12.8 GPUs need a CUDA 12.8 build (torch >= 2.7.1, or a nightly build)
- **See requirements.txt for all dependencies**
- **Ascend NPU support:** Install `torch_npu` and set `[hardware]`/`[ascend]` config as needed
### Hardware Support
- **NVIDIA GPUs**: RTX 5090, A100, H100 (CUDA 12.8+)
- **Huawei Ascend NPUs**: Ascend 910, 910B (with torch_npu)
- **CPU**: Fallback support for development/testing
---
## 🚀 Features
### Core Capabilities
- **State-of-the-art Techniques**: InfoNCE loss with proper temperature scaling, self-knowledge distillation, ANCE-style hard negative mining, RocketQAv2 joint training
- **Multi-functionality**: Dense, sparse (SPLADE-style), and multi-vector (ColBERT-style) retrieval
- **Cross-lingual Support**: Chinese primary with English query compatibility
- **Platform Compatibility**: NVIDIA CUDA and Huawei Ascend NPU support with automatic detection
- **Comprehensive Evaluation**: Built-in metrics (Recall@k, MRR@k, NDCG@k, MAP) with additional success rate and coverage metrics
- **Production Ready**: Robust error handling, atomic checkpointing, comprehensive logging, and distributed training
### Recent Enhancements
- **Enhanced Error Recovery**: Training continues despite batch failures with configurable thresholds
- **Atomic Checkpoint Saving**: Prevents corruption with temporary directories and backup mechanisms
- **Improved Memory Management**: Periodic cache clearing and gradient checkpointing for large models
- **Configuration Precedence**: Clear hierarchy (CLI > Model-specific > Hardware > Training > Defaults)
- **Unified AMP Imports**: Consistent `torch.cuda.amp` usage across all modules
- **Comprehensive Metrics**: Fixed MRR/NDCG implementations with edge case handling
- **Ascend NPU Backend**: Complete support module with device detection and optimization
- **Enhanced Dataset Validation**: Robust format detection with helpful error messages
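The atomic checkpoint saving idea can be sketched with the standard library alone. This is a minimal illustration, not this project's implementation: the `state.json` payload stands in for real model and optimizer files (normally written with `torch.save`), and the `atomic_save_checkpoint` name and `.bak` suffix are hypothetical.

```python
import json
import os
import shutil
import tempfile


def atomic_save_checkpoint(state, checkpoint_dir, keep_backup=True):
    """Save `state` so that `checkpoint_dir` never holds a half-written checkpoint."""
    checkpoint_dir = os.path.abspath(checkpoint_dir)  # also strips a trailing slash
    parent = os.path.dirname(checkpoint_dir)
    os.makedirs(parent, exist_ok=True)
    backup_dir = checkpoint_dir + ".bak"  # hypothetical backup naming
    # 1. Write into a temporary directory next to the target (same filesystem).
    tmp_dir = tempfile.mkdtemp(prefix=".tmp_ckpt_", dir=parent)
    try:
        with open(os.path.join(tmp_dir, "state.json"), "w", encoding="utf-8") as f:
            json.dump(state, f)
            f.flush()
            os.fsync(f.fileno())  # make sure bytes reach the disk before the rename
        # 2. Move the previous checkpoint aside rather than deleting it.
        if os.path.exists(checkpoint_dir):
            if os.path.exists(backup_dir):
                shutil.rmtree(backup_dir)
            os.rename(checkpoint_dir, backup_dir)
        # 3. Renaming a directory within one filesystem is atomic on POSIX.
        os.rename(tmp_dir, checkpoint_dir)
    except BaseException:
        shutil.rmtree(tmp_dir, ignore_errors=True)  # never leave partial files behind
        raise
    if not keep_backup and os.path.exists(backup_dir):
        shutil.rmtree(backup_dir)
```

If the process dies between steps 2 and 3, the previous checkpoint is still intact under the backup name, so a resume can fall back to it.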
---
## 📋 Technical Overview
### BGE-M3 (Embedding Model)
- **InfoNCE Loss** with temperature τ=0.02 (as reported in the BGE-M3 paper)
- **Proper Loss Normalization**: Fixed cross-entropy calculation in listwise loss
- **Self-Knowledge Distillation** combining dense, sparse, and multi-vector signals
- **ANCE-style Hard Negative Mining** with asynchronous ANN index updates
- **Multi-granularity Support** up to 8192 tokens
- **Unified Training** for all three functionalities
- **Gradient Checkpointing** for memory-efficient training on large batches
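To make the temperature concrete, here is a minimal pure-Python sketch of InfoNCE for a single query. It assumes L2-normalised embeddings (so the dot product equals cosine similarity) and places the positive at index 0 of the candidate list; the function name and argument layout are illustrative, not this project's API. A batched PyTorch version computes the same quantity with `torch.nn.functional.cross_entropy` over the similarity logits.

```python
import math


def info_nce_loss(query, positive, negatives, temperature=0.02):
    """InfoNCE for one query: -log softmax(sim / temperature) at the positive.

    All vectors are assumed L2-normalised, so the dot product is the cosine
    similarity. The positive is placed at index 0 of the logits.
    """
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    logits = [dot(query, positive) / temperature]
    logits += [dot(query, neg) / temperature for neg in negatives]
    # log-sum-exp with the max subtracted, since 1 / 0.02 = 50 inflates the logits
    m = max(logits)
    log_partition = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_partition - logits[0]
```

With τ = 0.02, a cosine gap of 0.1 between the positive and a negative becomes a logit gap of 5, which is why a small temperature sharpens the distribution.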
### BGE-Reranker (Cross-encoder)
- **Listwise Cross-Entropy** loss (fixed implementation, not KL divergence)
- **Dynamic Listwise Distillation** following RocketQAv2 approach
- **Hard Negative Selection** from retriever results
- **Knowledge Distillation** from teacher models
- **Mixed Precision Training** with bfloat16 support for RTX 5090
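A sketch of listwise cross-entropy for reranker training groups. It assumes the layout implied by `train_group_size`: each group holds one positive (placed first here, which is an assumption) followed by `train_group_size - 1` negatives, and the loss is softmax cross-entropy with the positive as the target class. The helper names are hypothetical; in PyTorch the batched equivalent is `F.cross_entropy` over scores reshaped to `(num_groups, group_size)` with an all-zero target tensor.

```python
import math


def listwise_ce_loss(group_scores, temperature=1.0):
    """Softmax cross-entropy over one group; group_scores[0] is the positive."""
    logits = [s / temperature for s in group_scores]
    m = max(logits)  # subtract the max for numerical stability
    log_partition = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_partition - logits[0]


def batch_listwise_ce(batch, temperature=1.0):
    """Mean loss over a batch of groups, each of length train_group_size."""
    return sum(listwise_ce_loss(g, temperature) for g in batch) / len(batch)
```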
### Joint Training (RocketQAv2)
- **Mutual Distillation** between retriever and reranker
- **Dynamic Hard Negative Mining** using evolving models
- **Hybrid Data Augmentation** strategies
- **Error-resilient Training Loop** with failure recovery
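RocketQAv2's dynamic listwise distillation aligns the retriever's score distribution over a candidate list with the reranker's distribution using KL divergence. A minimal sketch of that term for one query is below; how this project weights it against the other losses (for example via `--joint_loss_weight`) is not specified here, so any such combination is an assumption, and the function names are hypothetical.

```python
import math


def log_softmax(scores, temperature=1.0):
    """Numerically stable log-softmax of a list of scores."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    log_partition = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_partition for s in scaled]


def listwise_distill_loss(retriever_scores, reranker_scores, temperature=1.0):
    """KL(p_reranker || p_retriever) over the same candidate passages for one query."""
    log_p_teacher = log_softmax(reranker_scores, temperature)
    log_p_student = log_softmax(retriever_scores, temperature)
    return sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_teacher, log_p_student))
```

In joint training both models receive gradients from this term, which is what makes the distillation "mutual" rather than one-way.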
---
## 🛠 Installation
### Hardware Requirements
- **For RTX 5090**: Python 3.10+, CUDA >= 12.8, Driver >= 545.x
- **For Ascend NPU**: Python 3.10+, CANN >= 7.0, torch_npu
- **Memory**: 24GB+ GPU memory recommended for full model training
### Step 1: Install PyTorch (CUDA 12.8+ for RTX 5090)
```bash
# Install PyTorch nightly with CUDA 12.8 support
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
```
### Step 2: Install Dependencies
```bash
pip install -r requirements.txt
```
### Step 3: (Optional) Platform-specific Setup
#### For Huawei Ascend NPU:
```bash
# Install torch_npu (follow official Huawei documentation)
pip install torch_npu
# Set environment variables
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 # Your NPU devices
```
#### For Distributed Training:
```bash
# No additional setup needed - automatic backend selection:
# - NCCL for NVIDIA GPUs
# - HCCL for Ascend NPUs
# - Gloo for CPU
```
### Step 4: Download Models (Optional - auto-downloads if not present)
```bash
# Models will auto-download on first use, or manually download:
huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3
huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base
```
---
## 🔧 Configuration
### Configuration System Overview
The project uses a hierarchical configuration system with clear precedence rules and automatic validation. Configuration is managed through `config.toml` with programmatic overrides via command-line arguments.
### Configuration Precedence Order (Highest to Lowest)
1. **Command-line arguments** - Override any config file settings
2. **Model-specific sections** - `[m3]` for BGE-M3, `[reranker]` for BGE-Reranker
3. **Hardware-specific sections** - `[hardware]`, `[ascend]` for device-specific settings
4. **Training defaults** - `[training]` section for general defaults
5. **Dataclass field defaults** - Built-in defaults in code
Example precedence in action:
```toml
# config.toml
[training]
default_batch_size = 4
[hardware]
default_batch_size = 8 # Overrides [training] for hardware optimization
[m3]
per_device_train_batch_size = 16 # Overrides both [hardware] and [training] for M3 model
```
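One way to implement this lookup is to walk the layers from highest to lowest precedence and return the first value that is set. The sketch assumes the TOML sections have already been parsed into dictionaries and that keys have been normalised to one name (the example above uses `default_batch_size` in `[training]` but `per_device_train_batch_size` in `[m3]`, so a real loader needs such a mapping). `resolve`, the layer labels, and the values in `layers` are hypothetical. `None` is treated as "unset" so that CLI flags left at an argparse default of `None` fall through.

```python
# Layer names from highest to lowest precedence (hypothetical labels).
PRECEDENCE = ("cli", "model", "hardware", "training", "defaults")


def resolve(key, layers):
    """Return the value of `key` from the highest-precedence layer that sets it."""
    for name in PRECEDENCE:
        value = layers.get(name, {}).get(key)
        if value is not None:
            return value
    raise KeyError(f"{key!r} is not set in any configuration layer")


layers = {
    "cli": {},  # nothing passed on the command line
    "model": {"per_device_train_batch_size": 16},  # [m3]
    "hardware": {"per_device_train_batch_size": 8},  # [hardware]
    "training": {"per_device_train_batch_size": 4},  # [training], key already renamed
    "defaults": {"random_seed": 42},  # dataclass defaults (illustrative)
}
print(resolve("per_device_train_batch_size", layers))  # 16
```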
### Main Configuration File (config.toml)
```toml
# API Configuration for augmentation and external services
[api]
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here" # Replace with your actual API key
model = "deepseek-chat"
max_retries = 3
timeout = 30
temperature = 0.7
# General training defaults
[training]
default_batch_size = 4
default_learning_rate = 1e-5
default_num_epochs = 3
default_warmup_ratio = 0.1
default_weight_decay = 0.01
gradient_accumulation_steps = 4
max_grad_norm = 1.0
logging_steps = 100
save_strategy = "epoch"
evaluation_strategy = "epoch"
load_best_model_at_end = true
metric_for_best_model = "eval_loss"
greater_is_better = false
save_total_limit = 3
fp16 = true
bf16 = false # Enable for A100/RTX5090
gradient_checkpointing = false
dataloader_num_workers = 4
remove_unused_columns = true
label_names = ["labels"]
# Model paths with automatic download
[model_paths]
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
cache_dir = "./cache/models"
# Data processing configuration
[data]
max_query_length = 64
max_passage_length = 512
max_seq_length = 512
validation_split = 0.1
random_seed = 42
preprocessing_num_workers = 4
overwrite_cache = false
pad_to_max_length = false
return_tensors = "pt"
# Hardware-specific optimizations
[hardware]
device_type = "cuda" # cuda, npu, cpu
mixed_precision = true
mixed_precision_dtype = "fp16" # fp16, bf16
enable_tf32 = true # For Ampere GPUs
cudnn_benchmark = true
cudnn_deterministic = false
per_device_train_batch_size = 8
per_device_eval_batch_size = 16
# Distributed training configuration
[distributed]
backend = "nccl" # nccl, gloo, hccl (auto-detected if not set)
timeout_minutes = 30
find_unused_parameters = false
gradient_as_bucket_view = true
static_graph = false
# Huawei Ascend NPU specific settings
[ascend]
device_type = "npu"
backend = "hccl"
mixed_precision = true
mixed_precision_dtype = "fp16"
visible_devices = "0,1,2,3"
enable_graph_mode = false
enable_profiling = false
profiling_start_step = 10
profiling_stop_step = 20
# BGE-M3 specific configuration
[m3]
model_name_or_path = "BAAI/bge-m3"
learning_rate = 1e-5
per_device_train_batch_size = 16
per_device_eval_batch_size = 32
num_train_epochs = 5
warmup_ratio = 0.1
temperature = 0.02 # InfoNCE temperature (from paper)
use_self_distill = true
self_distill_weight = 1.0
use_hard_negatives = true
num_hard_negatives = 15
negatives_cross_batch = true
pooling_method = "cls"
normalize_embeddings = true
use_sparse = true
use_colbert = true
colbert_dim = 32
# BGE-Reranker specific configuration
[reranker]
model_name_or_path = "BAAI/bge-reranker-base"
learning_rate = 6e-5
per_device_train_batch_size = 8
per_device_eval_batch_size = 16
num_train_epochs = 3
warmup_ratio = 0.06
train_group_size = 16
loss_type = "listwise_ce" # listwise_ce, ranknet, lambdarank
use_gradient_checkpointing = true
temperature = 1.0
use_distillation = false
teacher_model_path = "BAAI/bge-reranker-large"
distillation_weight = 0.5
# Data augmentation settings
[augmentation]
enable_query_augmentation = false
enable_passage_augmentation = false
augmentation_methods = ["paraphrase", "back_translation", "mask_and_fill"]
num_augmentations_per_sample = 2
augmentation_temperature = 0.8
api_augmentation = true
local_augmentation_model = "google/flan-t5-base"
# Hard negative mining configuration
[hard_negative_mining]
enable_mining = true
mining_strategy = "ance" # ance, random, bm25
num_hard_negatives = 15
negative_depth = 200 # How deep to search for negatives
update_interval = 1000 # Steps between mining updates
margin = 0.1
use_cross_batch_negatives = true
dedup_negatives = true
# FAISS index settings for mining
[faiss]
index_type = "IVF" # Flat, IVF, HNSW
index_params = {nlist = 100, nprobe = 10}
use_gpu = true
normalize_vectors = true
# Checkpoint and recovery settings
[checkpoint]
save_strategy = "epoch"
save_steps = 1000
save_total_limit = 3
save_on_each_node = false
resume_from_checkpoint = true
checkpoint_save_dir = "./checkpoints"
use_atomic_save = true # Prevents corruption
keep_backup = true
# Logging configuration
[logging]
logging_dir = "./logs"
logging_strategy = "steps"
logging_steps = 100
logging_first_step = true
report_to = ["tensorboard", "wandb"]
wandb_project = "bge-finetuning"
wandb_run_name = null # Auto-generated if null
disable_tqdm = false
tqdm_mininterval = 1
# Evaluation settings
[evaluation]
evaluation_strategy = "epoch"
eval_steps = 500
eval_batch_size = 32
metric_for_best_model = "eval_loss"
load_best_model_at_end = true
compute_metrics_fn = "accuracy_and_f1"
greater_is_better = false
eval_accumulation_steps = null
# Memory optimization
[memory]
gradient_checkpointing = true
gradient_checkpointing_kwargs = {use_reentrant = false}
optim = "adamw_torch" # adamw_torch, adamw_hf, sgd, adafactor
optim_args = null
clear_cache_interval = 100 # Steps between cache clearing
max_memory_mb = null # Max memory usage before clearing
# Error handling and recovery
[error_handling]
continue_on_error = true
max_failed_batches = 10
error_log_file = "./logs/errors.log"
save_on_error = true
retry_failed_batches = true
num_retries = 3
```
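The `[error_handling]` options describe a loop that retries a failing batch, then skips it, and aborts only once too many batches have failed. A minimal sketch of those semantics, assuming `num_retries` counts attempts after the first one and that failures surface as `RuntimeError` (PyTorch's CUDA out-of-memory error is a subclass of it). `train_epoch` and `train_step` are hypothetical names, not this project's functions.

```python
import logging

logger = logging.getLogger(__name__)


def train_epoch(batches, train_step, continue_on_error=True, max_failed_batches=10,
                retry_failed_batches=True, num_retries=3):
    """Run train_step on every batch; retry, then skip, failing batches up to a limit."""
    attempts = 1 + (num_retries if retry_failed_batches else 0)
    failed = 0
    for i, batch in enumerate(batches):
        for attempt in range(1, attempts + 1):
            try:
                train_step(batch)
                break  # success: move on to the next batch
            except RuntimeError as err:
                logger.warning("batch %d failed (attempt %d/%d): %s", i, attempt, attempts, err)
        else:
            # for/else: this branch runs only when no attempt succeeded
            failed += 1
            if not continue_on_error or failed > max_failed_batches:
                raise RuntimeError(f"aborting after {failed} failed batch(es)")
    return failed
```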
### Command-Line Override Examples
```bash
# Override batch size and learning rate
python scripts/train_m3.py \
--per_device_train_batch_size 32 \
--learning_rate 2e-5
# Override hardware settings
python scripts/train_m3.py \
--device npu \
--mixed_precision_dtype bf16
# Override multiple config sections
python scripts/train_m3.py \
--config_file custom_config.toml \
--temperature 0.05 \
--num_hard_negatives 20 \
--gradient_checkpointing
```
### Environment Variable Support
```bash
# Set API key via environment variable
export BGE_API_KEY="your-api-key"
# Override config file location
export BGE_CONFIG_FILE="/path/to/custom/config.toml"
# Set distributed training environment
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export WORLD_SIZE=4
```
### Configuration Validation
The configuration system automatically validates settings:
```python
# In your code
from config.base_config import BaseConfig
config = BaseConfig.from_toml("config.toml")
# Automatic validation includes:
# - Type checking (int, float, str, bool, list, dict)
# - Value range validation (e.g., learning_rate > 0)
# - Compatibility checks (e.g., bf16 requires compatible hardware)
# - Missing required fields
# - Deprecated parameter warnings
```
### Dynamic Configuration
```python
# Programmatic configuration updates
import torch
from config.m3_config import M3Config

config = M3Config.from_toml("config.toml")

# Update based on hardware: compute capability >= 8 (Ampere and newer) supports bf16
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    config.mixed_precision_dtype = "bf16"
    config.per_device_train_batch_size *= 2

# Update based on dataset size (one JSONL record per line)
with open("data/train.jsonl", encoding="utf-8") as f:
    dataset_size = sum(1 for _ in f)
if dataset_size > 1000000:
    config.gradient_accumulation_steps = 8
    config.logging_steps = 500
```
### Best Practices
1. **Start with defaults**: The provided `config.toml` has sensible defaults
2. **Hardware-specific tuning**: Use `[hardware]` section for GPU/NPU optimization
3. **Model-specific settings**: Use `[m3]` or `[reranker]` for model-specific params
4. **Environment isolation**: Use different config files for dev/staging/prod
5. **Version control**: Track config.toml changes in git
6. **Secrets management**: Use environment variables for API keys
---
## 🚀 Quick Start
### 1. Prepare Your Data
Create your training data in JSONL format. Each line should be a valid JSON object.
**For embedding models (BGE-M3):**
```jsonl
{"query": "What is machine learning?", "pos": ["Machine learning is..."], "neg": ["The weather is..."]}
{"query": "How to cook pasta?", "pos": ["Boil water and add pasta..."], "neg": ["Machine learning..."]}
```
**For reranker models:**
```jsonl
{"query": "What is AI?", "passage": "AI is artificial intelligence...", "label": 1}
{"query": "What is AI?", "passage": "The weather today is sunny...", "label": 0}
```
### 2. Fine-tune BGE-M3 (Embedding Model)
```bash
# Basic training with automatic hardware detection
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-m3-finetuned \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--learning_rate 1e-5
# Advanced training with all features enabled
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-m3-advanced \
--num_train_epochs 5 \
--per_device_train_batch_size 8 \
--learning_rate 1e-5 \
--use_self_distill \
--use_hard_negatives \
--gradient_checkpointing \
--bf16 \
--save_total_limit 3 \
--logging_steps 100 \
--evaluation_strategy "steps" \
--eval_steps 500 \
--metric_for_best_model "eval_loss" \
--load_best_model_at_end
```
### 3. Fine-tune BGE-Reranker (Cross-encoder)
```bash
# Basic reranker training
python scripts/train_reranker.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-reranker-finetuned \
--learning_rate 6e-5 \
--train_group_size 16 \
--loss_type listwise_ce
# Advanced reranker training with knowledge distillation
python scripts/train_reranker.py \
--train_data data/train.jsonl \
--output_dir ./output/bge-reranker-advanced \
--learning_rate 6e-5 \
--train_group_size 16 \
--loss_type listwise_ce \
--use_distillation \
--teacher_model_path BAAI/bge-reranker-large \
--bf16 \
--gradient_checkpointing \
--save_strategy "epoch" \
--evaluation_strategy "epoch"
```
### 4. Joint Training (RocketQAv2-style)
```bash
# Joint training with mutual distillation
python scripts/train_joint.py \
--train_data data/train.jsonl \
--output_dir ./output/joint-training \
--use_dynamic_distillation \
--joint_loss_weight 0.3 \
--bf16 \
--gradient_checkpointing \
--max_grad_norm 1.0 \
--dataloader_num_workers 4
```
### 5. Multi-GPU / Distributed Training
```bash
# NVIDIA GPUs (automatic NCCL backend)
torchrun --nproc_per_node=4 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/distributed \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 2 \
--bf16 \
--gradient_checkpointing
# Huawei Ascend NPUs (automatic HCCL backend)
# First, configure config.toml [ascend] section, then:
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/ascend \
--device npu
```
---
## 📊 Data Format
### Canonical Format (Required Structure)
All data formats must contain these required fields:
- `query` (string): The search query or question
- `pos` (list of strings): Positive (relevant) passages
- `neg` (list of strings): Negative (irrelevant) passages
Optional fields for advanced training:
- `pos_scores` (list of floats): Relevance scores for positives (0.0-1.0)
- `neg_scores` (list of floats): Relevance scores for negatives (0.0-1.0)
- `prompt` (string): Custom prompt for encoding
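A minimal sketch of the checks these rules imply for a single record. Two details are assumptions rather than documented behaviour: `pos` must contain at least one passage, and `pos_scores`/`neg_scores`, when present, must have exactly one score per passage. `validate_record` is a hypothetical helper, not the project's loader.

```python
import json


def validate_record(rec):
    """Return a list of problems with one canonical record (empty list if valid)."""
    errors = []
    query = rec.get("query")
    if not isinstance(query, str) or not query.strip():
        errors.append("'query' must be a non-empty string")
    for field in ("pos", "neg"):
        passages = rec.get(field)
        if not isinstance(passages, list) or not all(isinstance(p, str) for p in passages):
            errors.append(f"'{field}' must be a list of strings")
    if isinstance(rec.get("pos"), list) and not rec["pos"]:
        errors.append("'pos' must contain at least one passage")
    for field in ("pos", "neg"):
        scores = rec.get(f"{field}_scores")
        if scores is None:
            continue  # scores are optional
        passages = rec.get(field)
        if not isinstance(scores, list) or not isinstance(passages, list) or len(scores) != len(passages):
            errors.append(f"'{field}_scores' must have one score per entry in '{field}'")
        elif not all(isinstance(s, (int, float)) and 0.0 <= s <= 1.0 for s in scores):
            errors.append(f"'{field}_scores' values must lie in [0.0, 1.0]")
    return errors


line = '{"query": "What is BGE?", "pos": ["BGE is..."], "neg": [], "pos_scores": [1.0]}'
print(validate_record(json.loads(line)))  # []
```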
### Supported Input Formats
Our dataset loader automatically detects and validates the following formats:
#### 1. JSONL (Recommended - Most Efficient)
```jsonl
{"query": "What is machine learning?", "pos": ["Machine learning is..."], "neg": ["Deep learning is..."]}
{"query": "深度学习的应用", "pos": ["深度学习在计算机视觉..."], "neg": ["机器学习是..."]}
```
#### 2. JSON (Array Format)
```json
[
  {
    "query": "What is BGE?",
    "pos": ["BGE is a family of embedding models..."],
    "neg": ["BERT is a transformer model..."],
    "pos_scores": [1.0],
    "neg_scores": [0.2]
  }
]
```
#### 3. CSV/TSV
```csv
query,pos,neg,pos_scores,neg_scores
"What is machine learning?","['ML is...', 'Machine learning refers to...']","['DL is...']","[1.0, 0.9]","[0.1]"
```
#### 4. DuReader Format
```json
{
  "data": [{
    "question": "什么是机器学习?",
    "answers": ["机器学习是..."],
    "documents": [{
      "content": "深度学习是...",
      "is_relevant": false
    }]
  }]
}
```
#### 5. MS MARCO Format (TSV)
```tsv
query_id query doc_id document relevance
1001 What is ML? 2001 Machine learning is... 1
1001 What is ML? 2002 Deep learning is... 0
```
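Because each MS MARCO row describes a single query-document pair, converting to the canonical format means grouping rows by `query_id`. A sketch assuming tab-separated columns with the header shown and treating `relevance > 0` as positive; `msmarco_to_records` is a hypothetical name.

```python
import csv
import io


def msmarco_to_records(tsv_text):
    """Group MS MARCO-style rows by query_id into canonical {query, pos, neg} records."""
    grouped = {}  # dicts keep insertion order, so queries stay in file order
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        rec = grouped.setdefault(row["query_id"], {"query": row["query"], "pos": [], "neg": []})
        target = "pos" if int(row["relevance"]) > 0 else "neg"
        rec[target].append(row["document"])
    return list(grouped.values())
```

`quoting=csv.QUOTE_NONE` keeps stray double quotes inside documents from being treated as CSV quoting.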
#### 6. XML Format (TREC-style)
```xml
<topics>
  <topic>
    <query>Information retrieval basics</query>
    <positive>IR is the process of obtaining...</positive>
    <negative>Database management is...</negative>
  </topic>
</topics>
```
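A sketch of reading this layout with the standard-library `xml.etree.ElementTree`, producing one canonical record per `<topic>`. Allowing several `<positive>`/`<negative>` elements per topic is an assumption beyond the single-element example; `trec_xml_to_records` is a hypothetical name.

```python
import xml.etree.ElementTree as ET


def trec_xml_to_records(xml_text):
    """Convert each <topic> element into a canonical {query, pos, neg} record."""
    root = ET.fromstring(xml_text)
    records = []
    for topic in root.iter("topic"):
        records.append({
            "query": (topic.findtext("query") or "").strip(),
            "pos": [(el.text or "").strip() for el in topic.findall("positive")],
            "neg": [(el.text or "").strip() for el in topic.findall("negative")],
        })
    return records
```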
### Data Validation & Error Handling
The dataset loader performs comprehensive validation:
- **Format Detection**: Automatic format inference with helpful error messages
- **Field Validation**: Ensures all required fields are present and properly typed
- **Content Validation**: Checks for empty queries/passages, invalid scores
- **Encoding Validation**: Handles various text encodings (UTF-8, GB2312, etc.)
- **Size Validation**: Warns about extremely long sequences
Example validation output:
```
Dataset validation completed:
✓ Format: JSONL
✓ Total samples: 10,000
✓ Valid samples: 9,998
✗ Invalid samples: 2
- Line 1523: Empty query field
- Line 7891: Missing 'neg' field
✓ Encoding: UTF-8
⚠ Warnings:
- 15 samples have passages > 8192 tokens (will be truncated)
```
### Converting Data Formats
Use the built-in preprocessor to convert between formats:
```python
from data.preprocessing import DataPreprocessor
preprocessor = DataPreprocessor()
# Convert any format to canonical JSONL
preprocessor.convert_to_canonical(
    input_file="data/msmarco.tsv",
    output_file="data/train.jsonl",
    input_format="msmarco"  # or "auto" for automatic detection
)
# Batch conversion with validation
preprocessor.batch_convert(
    input_files=["data/*.csv", "data/*.json"],
    output_dir="data/processed/",
    validate=True,
    max_workers=4
)
```
### Working with Book/Document Collections
For training on books or large documents:
1. **Split documents into passages**:
```python
from data.preprocessing import DocumentSplitter
splitter = DocumentSplitter(
    chunk_size=512,
    overlap=50,
    split_by="sentence"  # or "paragraph", "sliding_window"
)
passages = splitter.split_document("path/to/book.txt")
```
2. **Generate training pairs**:
```python
from data.preprocessing import TrainingPairGenerator
generator = TrainingPairGenerator()
# Generate query-passage pairs with hard negatives
training_data = generator.generate_from_passages(
    passages=passages,
    queries_per_passage=2,
    hard_negative_ratio=0.7,
    use_bm25=True
)
# Save in canonical format
generator.save_jsonl(training_data, "data/book_train.jsonl")
```
3. **Use data optimization**:
```bash
# Optimize data for BGE-M3 training
python -m data.optimization bge-m3 data/book_train.jsonl \
--output_dir ./optimized_data \
--hard_negative_ratio 0.8 \
--augment_queries
```
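The `chunk_size`/`overlap` parameters of the splitter describe a sliding window. A minimal character-based sketch follows; counting characters is an assumption (Chinese text has no spaces, so characters are a rough proxy for tokens), and a tokenizer-based splitter would count tokens instead. `sliding_window_chunks` is hypothetical, not the project's `DocumentSplitter`.

```python
def sliding_window_chunks(text, chunk_size=512, overlap=50):
    """Split text into windows of chunk_size characters; consecutive windows share `overlap` characters."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks
```

The overlap keeps a sentence that straddles a boundary fully visible in at least one passage.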
### Memory-Efficient Data Loading
For large datasets, use streaming mode:
```python
# In your training script
dataset = ContrastiveDataset(
    data_file="data/large_dataset.jsonl",
    tokenizer=tokenizer,
    streaming=True,  # Enable streaming
    buffer_size=10000  # Buffer size for shuffling
)
```
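A streamed dataset cannot be shuffled globally, so `buffer_size` presumably controls a shuffle buffer (an assumption about `ContrastiveDataset`). The standard technique keeps `buffer_size` items in memory and emits a random one each time a new item arrives; `buffered_shuffle` is a hypothetical name.

```python
import random


def buffered_shuffle(stream, buffer_size=10000, seed=None):
    """Yield items from an iterable in approximately random order using bounded memory."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        # Emit a random buffered item and put the new one in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)  # drain what is left in random order
    yield from buffer
```

A larger buffer gives a closer approximation to a full shuffle at the cost of memory.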
### 📖 Training with Book `.txt` Files
- **You cannot train directly on a raw `.txt` book file.**
- **Convert your book to a supported format:**
  - Split the book into passages (e.g., by chapter or paragraph).
  - Create queries and annotate which passages are positive (relevant) and negative (irrelevant) for each query.
  - Example:
    ```jsonl
    {"query": "What happens in Chapter 1?", "pos": ["Chapter 1: ..."], "neg": ["Chapter 2: ..."]}
    ```
- Use the provided `DataPreprocessor` to convert CSV/TSV/JSON/XML to JSONL.
**Note:** All formats must be convertible to the canonical JSONL structure above.
---
## ⚙️ Configuration Reference
### Parameter Precedence
- `[hardware]` overrides `[training]` for overlapping parameters (e.g., batch size).
- `[m3]` and `[reranker]` override `[training]`/`[hardware]` for their respective models.
- For Ascend NPU, use the `[ascend]` section and set `device_type = "npu"`.
### Hardware Compatibility
- **NVIDIA RTX 5090:** Use torch >=2.7.1 and CUDA 12.8+ (see above).
- **Huawei Ascend NPU:** Install `torch_npu`, set `[hardware]` or `[ascend]` config, and ensure your environment matches Ascend requirements.
### API Configuration (config.toml)
```toml
[api]
base_url = "https://api.deepseek.com/v1"
api_key = "your-api-key-here"
model = "deepseek-chat"
[training]
default_batch_size = 4
default_learning_rate = 1e-5
[model_paths]
bge_m3 = "./models/bge-m3"
bge_reranker = "./models/bge-reranker-base"
```
### Configuration Parameter Reference
#### [api]
| Parameter | Type | Default | Description |
| -------------- | ------- | ------------------------ | ------------------------------------------- |
| base_url | str | "https://api.deepseek..."| API endpoint for augmentation |
| api_key | str | "your-api-key-here" | API key for external services |
| model | str | "deepseek-chat" | Model name for API |
| max_retries | int | 3 | Max API retries |
| timeout | int | 30 | API timeout (seconds) |
| temperature | float | 0.7 | API sampling temperature |
#### [training]
| Parameter | Type | Default | Description |
|--------------------------|---------|---------|---------------------------------------------|
| default_batch_size | int | 4 | Default batch size for training |
| default_learning_rate | float | 1e-5 | Default learning rate |
| default_num_epochs | int | 3 | Default number of epochs |
| default_warmup_ratio | float | 0.1 | Warmup steps ratio |
| default_weight_decay | float | 0.01 | Weight decay for optimizer |
| gradient_accumulation_steps | int | 4 | Steps to accumulate gradients |
#### [model_paths]
| Parameter | Type | Default | Description |
| -------------- | ------- | ------------------------ | ------------------------------------------- |
| bge_m3 | str | "./models/bge-m3" | Path to BGE-M3 model |
| bge_reranker | str | "./models/bge-reranker-base" | Path to BGE-Reranker model |
| cache_dir | str | "./cache/models" | Model cache directory |
#### [data]
| Parameter | Type | Default | Description |
|---------------------|---------|---------|---------------------------------------------|
| max_query_length | int | 64 | Max query length |
| max_passage_length | int | 512 | Max passage length |
| max_seq_length | int | 512 | Max sequence length |
| validation_split | float | 0.1 | Validation split ratio |
| random_seed | int | 42 | Random seed for reproducibility |
#### [augmentation]
| Parameter | Type | Default | Description |
|-----------------------------|---------|--------------------------|---------------------------------------------|
| enable_query_augmentation | bool | false | Enable query augmentation |
| enable_passage_augmentation | bool | false | Enable passage augmentation |
| augmentation_methods | list | ["paraphrase", "back_translation"] | Augmentation methods to use |
| num_augmentations_per_sample| int | 2 | Number of augmentations per sample |
| augmentation_temperature | float | 0.8 | Temperature for augmentation |
| api_first | bool | true | Prefer LLM API for all augmentation if enabled |
#### [hard_negative_mining]
| Parameter | Type | Default | Description |
|---------------------|---------|---------|---------------------------------------------|
| enable_mining | bool | true | Enable hard negative mining |
| num_hard_negatives | int | 15 | Number of hard negatives to mine |
| negative_range_min | int | 10 | Min negative index for mining |
| negative_range_max | int | 100 | Max negative index for mining |
| margin | float | 0.1 | Margin for ANCE mining |
| index_refresh_interval | int | 1000 | Steps between index refreshes |
| index_type | str | "IVF" | FAISS index type |
| nlist | int | 100 | Number of clusters for IVF |
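The `negative_range_min`/`negative_range_max` parameters suggest the common pattern of skipping the very top of a dense ranking (which often contains unlabeled positives) and drawing negatives from a deeper window. A brute-force sketch under that assumption, with exact dot-product search standing in for the FAISS index and the highest-ranked candidates in the window taken first (random sampling within the window is a common alternative). `mine_hard_negatives` is hypothetical.

```python
def mine_hard_negatives(query_vec, passage_vecs, positive_ids, num_hard_negatives=15,
                        range_min=10, range_max=100):
    """Return passage indices ranked in [range_min, range_max) that are not positives."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Exact search: rank every passage by similarity to the query, best first.
    ranked = sorted(range(len(passage_vecs)),
                    key=lambda i: dot(query_vec, passage_vecs[i]), reverse=True)
    window = [i for i in ranked[range_min:range_max] if i not in positive_ids]
    return window[:num_hard_negatives]
```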
#### [m3] and [reranker]
- All model-specific parameters are now documented in the config classes and code docstrings. See `config/m3_config.py` and `config/reranker_config.py` for full details.
---
## ⚡️ Distributed Training
### Overview
The project supports distributed training across multiple GPUs/NPUs with automatic backend detection, error recovery, and optimized communication strategies. The distributed training system is fully configuration-driven and handles both data-parallel and model-parallel strategies.
### Automatic Backend Detection
The system automatically selects the optimal backend based on your hardware:
- **NVIDIA GPUs**: NCCL (fastest for GPU-to-GPU communication)
- **Huawei Ascend NPUs**: HCCL (optimized for Ascend architecture)
- **CPU or Mixed**: Gloo (fallback for heterogeneous setups)
### Basic Usage
#### Single-Node Multi-GPU (NVIDIA)
```bash
# Automatic detection and setup
torchrun --nproc_per_node=4 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/distributed
# With specific GPU selection
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/distributed \
--gradient_accumulation_steps 2
```
#### Multi-Node Training (NVIDIA)
```bash
# On node 0 (master)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr=192.168.1.1 --master_port=29500 \
scripts/train_m3.py --train_data data/train.jsonl
# On node 1
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
--master_addr=192.168.1.1 --master_port=29500 \
scripts/train_m3.py --train_data data/train.jsonl
```
#### Huawei Ascend NPU Training
```bash
# Single-node multi-NPU
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/ascend \
--device npu
# With specific NPU selection
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/ascend
```
### Configuration Options
```toml
# config.toml
[distributed]
backend = "auto" # auto, nccl, gloo, hccl
timeout_minutes = 30
find_unused_parameters = false
gradient_as_bucket_view = true
broadcast_buffers = true
bucket_cap_mb = 25
use_distributed_optimizer = false
# Performance optimizations
[distributed.optimization]
overlap_comm = true
gradient_compression = false
compression_type = "none" # none, powerSGD, fp16
all_reduce_fusion = true
fusion_threshold_mb = 64
# Fault tolerance
[distributed.fault_tolerance]
enable_checkpointing = true
checkpoint_interval = 1000
restart_on_failure = true
max_restart_attempts = 3
heartbeat_interval = 60
```
### Advanced Features
#### 1. Gradient Accumulation with Distributed Training
```bash
# Effective batch size = per_device_batch * gradient_accumulation * num_gpus
# Example: 8 * 4 * 4 = 128 effective batch size
python scripts/train_m3.py \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 4 \
--distributed
```
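The effective-batch arithmetic above, together with the linear learning-rate scaling heuristic from the Best Practices below (base batch size 16), fits in two small helpers. Both names are illustrative, and linear scaling is a rule of thumb that usually needs warmup, not a guarantee.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, world_size):
    """Number of samples contributing to one optimizer step across all devices."""
    return per_device_batch_size * gradient_accumulation_steps * world_size


def scaled_learning_rate(base_lr, total_batch_size, base_batch_size=16):
    """Linear scaling heuristic: LR = base_lr * (total_batch_size / base_batch_size)."""
    return base_lr * total_batch_size / base_batch_size


total = effective_batch_size(8, 4, 4)
print(total)  # 128
print(scaled_learning_rate(1e-5, total))
```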
#### 2. Mixed Precision with Distributed
```bash
# Automatic mixed precision for faster training
torchrun --nproc_per_node=4 scripts/train_m3.py \
--bf16 \
--tf32 # Additional optimization for Ampere GPUs
```
#### 3. DeepSpeed Integration
```bash
# With DeepSpeed for extreme scale
deepspeed scripts/train_m3.py \
--deepspeed deepspeed_config.json \
--train_data data/train.jsonl
```
Example DeepSpeed config:
```json
{
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto",
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000
  },
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  }
}
```
#### 4. Fault Tolerance and Recovery
The distributed training system includes robust error handling:
```python
# Training automatically recovers from:
# - Temporary network failures
# - GPU memory errors (with retry)
# - Node failures (with checkpoint recovery)
# - Communication timeouts
```
#### 5. Performance Monitoring
```bash
# Enable distributed profiling
python scripts/train_m3.py \
--enable_profiling \
--profiling_start_step 100 \
--profiling_end_step 200
# Monitor with tensorboard
tensorboard --logdir ./logs --bind_all
```
### Best Practices
1. **Data Loading**
```bash
# Use distributed sampler (automatic)
# Set num_workers based on CPUs
--dataloader_num_workers 4
```
2. **Batch Size Scaling**
```bash
# Scale learning rate with batch size:
# LR = base_lr * (total_batch_size / base_batch_size)
# e.g. 4 GPUs x 8 per device = 32 total -> 1e-5 * (32 / 16) = 2e-5
--learning_rate 2e-5
```
3. **Communication Optimization**
```bash
# Reduce communication overhead
--gradient_accumulation_steps 4
--broadcast_buffers false # If not needed
```
4. **Memory Optimization**
```bash
# For large models
--gradient_checkpointing
--zero_stage 2 # With DeepSpeed
```
### Troubleshooting
#### NCCL Errors
```bash
# Debug NCCL issues
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
# Common fixes
export NCCL_IB_DISABLE=1 # Disable InfiniBand
export NCCL_P2P_DISABLE=1 # Disable P2P
```
#### Hanging Training
```bash
# Set timeout for debugging
export NCCL_TIMEOUT=300 # 5 minutes
export NCCL_ASYNC_ERROR_HANDLING=1
```
#### Memory Issues
```bash
# Clear cache periodically
--clear_cache_interval 100
# Use gradient checkpointing
--gradient_checkpointing
```
### Performance Benchmarks
Indicative speedups with distributed training (actual scaling depends on the factors listed below):
- 2 GPUs: 1.8-1.9x speedup
- 4 GPUs: 3.5-3.8x speedup
- 8 GPUs: 6.5-7.5x speedup
Factors affecting performance:
- Model size (larger models scale better)
- Batch size (larger batches reduce communication overhead)
- Network bandwidth (especially for multi-node)
- Data loading speed
---
## 📊 Evaluation Metrics
### Overview
The evaluation system provides comprehensive metrics for both retrieval and reranking tasks, with support for graded relevance, statistical significance testing, and performance benchmarking.
### Core Metrics
#### Retrieval Metrics
1. **Recall@K** - Fraction of relevant documents retrieved in top K
```python
recall = |retrieved_relevant| / |total_relevant|
```
2. **Precision@K** - Fraction of retrieved documents that are relevant
```python
precision = |retrieved_relevant| / K
```
3. **Mean Average Precision (MAP)** - Average precision at each relevant position
```python
MAP = mean([AP(q) for q in queries])
where AP = sum([P@k * rel(k) for k in positions]) / |relevant|
```
4. **Mean Reciprocal Rank (MRR@K)** - Average reciprocal of first relevant result rank
```python
MRR = mean([1/rank_of_first_relevant for q in queries])
```
5. **Normalized Discounted Cumulative Gain (NDCG@K)** - Position-weighted relevance score
```python
DCG@K = sum([rel(i) / log2(i + 1) for i in 1..K])
NDCG@K = DCG@K / IDCG@K
```
6. **Success Rate@K** - Fraction of queries with at least one relevant result in top K
```python
success_rate = |queries_with_relevant_in_topK| / |total_queries|
```
7. **Coverage@K** - Fraction of all relevant documents found across all queries
```python
coverage = |unique_relevant_retrieved| / |total_unique_relevant|
```
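The definitions above can be sketched in plain Python. This is a minimal, dependency-free illustration, assuming each query has a ranked list of document IDs (`ranked`), a set of relevant IDs (`relevant`) and, for NDCG, a mapping from ID to graded relevance (`gains`). The function names are hypothetical and are not the API of `evaluation/metrics.py`.

```python
import math
from typing import Dict, List, Sequence, Set


def recall_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """|relevant docs in top k| / |relevant docs|; 0.0 if the query has none."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & relevant) / len(relevant)


def precision_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """|relevant docs in top k| / k."""
    return len(set(ranked[:k]) & relevant) / k


def average_precision(ranked: Sequence[str], relevant: Set[str]) -> float:
    """Sum of P@i over every rank i holding a relevant doc, divided by |relevant|."""
    if not relevant:
        return 0.0
    hits, total = 0, 0.0
    for i, doc in enumerate(ranked, start=1):
        if doc in relevant:
            hits += 1
            total += hits / i  # precision at rank i
    return total / len(relevant)


def reciprocal_rank(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """1 / rank of the first relevant doc within the top k, else 0.0."""
    for i, doc in enumerate(ranked[:k], start=1):
        if doc in relevant:
            return 1.0 / i
    return 0.0


def ndcg_at_k(ranked: Sequence[str], gains: Dict[str, float], k: int) -> float:
    """DCG@k / IDCG@k using the linear gain rel(i) / log2(i + 1) from the formula above."""
    dcg = sum(gains.get(doc, 0.0) / math.log2(i + 1)
              for i, doc in enumerate(ranked[:k], start=1))
    ideal = sorted(gains.values(), reverse=True)[:k]
    idcg = sum(g / math.log2(i + 1) for i, g in enumerate(ideal, start=1))
    return dcg / idcg if idcg > 0 else 0.0


def success_rate_at_k(runs: List[Sequence[str]], qrels: List[Set[str]], k: int) -> float:
    """Fraction of queries with at least one relevant doc in the top k."""
    if not runs:
        return 0.0
    hits = sum(1 for ranked, rel in zip(runs, qrels) if set(ranked[:k]) & rel)
    return hits / len(runs)


def coverage_at_k(runs: List[Sequence[str]], qrels: List[Set[str]], k: int) -> float:
    """Unique relevant docs found in any top k / all unique relevant docs."""
    all_relevant: Set[str] = set().union(*qrels)
    found: Set[str] = set()
    for ranked, rel in zip(runs, qrels):
        found |= set(ranked[:k]) & rel
    return len(found) / len(all_relevant) if all_relevant else 0.0
```

MAP is then the mean of `average_precision` over all queries, and MRR@K the mean of `reciprocal_rank`.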
#### Reranking Metrics
1. **Accuracy@1** - Top-1 ranking accuracy
2. **MRR** - Mean reciprocal rank for reranked results
3. **NDCG@K** - Ranking quality with graded relevance
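4. **ERR (Expected Reciprocal Rank)** - Cascade-model metric: the expected reciprocal of the rank at which a user, scanning results top-down, stops after finding a satisfying one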
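Accuracy@1 and ERR can be sketched as follows, assuming each query's candidates are given as relevance grades listed in reranked order (highest score first). ERR uses the standard cascade formulation, in which a document of grade g satisfies the user with probability (2^g - 1) / 2^g_max. The function names are hypothetical and are not the project's evaluator API.

```python
from typing import Sequence


def accuracy_at_1(ranked_labels: Sequence[Sequence[int]]) -> float:
    """Fraction of queries whose top-ranked candidate has a positive label."""
    if not ranked_labels:
        return 0.0
    correct = sum(1 for labels in ranked_labels if labels and labels[0] > 0)
    return correct / len(ranked_labels)


def expected_reciprocal_rank(labels: Sequence[int], max_grade: int, k: int = 10) -> float:
    """ERR@k: the user scans top-down and stops at rank r with probability R_r."""
    err = 0.0
    p_reach = 1.0  # probability the user has not stopped before the current rank
    for r, grade in enumerate(labels[:k], start=1):
        p_stop = (2 ** grade - 1) / 2 ** max_grade
        err += p_reach * p_stop / r
        p_reach *= 1.0 - p_stop
    return err
```

Because of the `p_reach` factor, a document lower in the list contributes only to the extent that the documents above it failed to satisfy the user.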
### Usage Examples
#### Basic Evaluation
```bash
# Evaluate retriever
python scripts/evaluate.py \
--eval_type retriever \
--model_path ./output/bge-m3-finetuned \
--eval_data data/test.jsonl \
--k_values 1 5 10 20 50 \
--output_file results/retriever_metrics.json
# Evaluate reranker
python scripts/evaluate.py \
--eval_type reranker \
--model_path ./output/bge-reranker-finetuned \
--eval_data data/test.jsonl \
--output_file results/reranker_metrics.json
# Evaluate full pipeline
python scripts/evaluate.py \
--eval_type pipeline \
--retriever_path ./output/bge-m3-finetuned \
--reranker_path ./output/bge-reranker-finetuned \
--eval_data data/test.jsonl \
--retrieval_top_k 50 \
--rerank_top_k 10
```
#### Advanced Evaluation with Graded Relevance
```python
# One record of data/test_graded.jsonl, pretty-printed (in the file it occupies a single line)
{
"query": "machine learning basics",
"documents": [
{"text": "Introduction to ML", "relevance": 2}, # Highly relevant
{"text": "Deep learning guide", "relevance": 1}, # Somewhat relevant
{"text": "Database systems", "relevance": 0} # Not relevant
]
}
```
```bash
# Evaluate with graded relevance
python scripts/evaluate.py \
--eval_type retriever \
--model_path ./output/bge-m3-finetuned \
--eval_data data/test_graded.jsonl \
--use_graded_relevance \
--relevance_threshold 1 \
--ndcg_k_values 5 10 20
```
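One plausible reading of `--relevance_threshold 1` is that grades at or above the threshold count as relevant for binary metrics (Recall, MRR, Success Rate), while NDCG keeps the raw grades. The sketch below shows that split for the record above. Both this interpretation and the helper name `split_labels` are assumptions, not a description of `evaluate.py`.

```python
import json
from typing import Dict, Set, Tuple


def split_labels(record_line: str, relevance_threshold: int = 1) -> Tuple[Set[str], Dict[str, int]]:
    """Parse one JSONL record into (binary-relevant texts, graded gains by text)."""
    record = json.loads(record_line)
    gains = {doc["text"]: int(doc["relevance"]) for doc in record["documents"]}
    relevant = {text for text, grade in gains.items() if grade >= relevance_threshold}
    return relevant, gains


# The example record above, serialized as a single JSONL line.
line = json.dumps({
    "query": "machine learning basics",
    "documents": [
        {"text": "Introduction to ML", "relevance": 2},
        {"text": "Deep learning guide", "relevance": 1},
        {"text": "Database systems", "relevance": 0},
    ],
})
relevant, gains = split_labels(line, relevance_threshold=1)
```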
#### Statistical Significance Testing
```bash
# Compare two models with significance testing
python scripts/compare_retriever.py \
--baseline_model_path BAAI/bge-m3 \
--finetuned_model_path ./output/bge-m3-finetuned \
--data_path data/test.jsonl \
--num_bootstrap_samples 1000 \
--confidence_level 0.95
# Output includes:
# - Metric comparisons with confidence intervals
# - Statistical significance (p-values)
# - Effect sizes (Cohen's d)
```
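The confidence intervals and effect sizes above can be illustrated with a paired bootstrap over per-query scores. This is a generic sketch, assuming you already have one metric value per query for each model, aligned by query. It is not the implementation in `compare_retriever.py`, and the Cohen's d shown is the paired variant (mean difference divided by the standard deviation of the differences).

```python
import math
import random
import statistics
from typing import Sequence, Tuple


def paired_bootstrap_ci(baseline: Sequence[float], finetuned: Sequence[float],
                        num_samples: int = 1000, confidence: float = 0.95,
                        seed: int = 0) -> Tuple[float, float, float]:
    """Return (mean improvement, lower bound, upper bound) via a percentile bootstrap."""
    if len(baseline) != len(finetuned) or not baseline:
        raise ValueError("need two non-empty, equally long lists of per-query scores")
    diffs = [f - b for b, f in zip(baseline, finetuned)]
    rng = random.Random(seed)
    # Resample queries with replacement; resampling the differences keeps the pairing.
    means = sorted(statistics.fmean(rng.choices(diffs, k=len(diffs)))
                   for _ in range(num_samples))
    alpha = (1.0 - confidence) / 2
    lower = means[int(alpha * num_samples)]
    upper = means[min(num_samples - 1, int((1.0 - alpha) * num_samples))]
    return statistics.fmean(diffs), lower, upper


def paired_cohens_d(baseline: Sequence[float], finetuned: Sequence[float]) -> float:
    """Paired effect size: mean difference / sample standard deviation of the differences."""
    diffs = [f - b for b, f in zip(baseline, finetuned)]
    if len(diffs) < 2:
        raise ValueError("need at least two queries")
    mean, sd = statistics.fmean(diffs), statistics.stdev(diffs)
    if sd == 0:
        return 0.0 if mean == 0 else math.copysign(math.inf, mean)
    return mean / sd
```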
### Metric Computation Details
#### Handling Edge Cases
1. **Empty Results**
   - Recall: 0, whether or not the query has relevant documents
   - Precision: 0 for empty results
   - MRR: 0 if no relevant document is found
   - NDCG: 0 for empty results
2. **Tied Scores**
   - Uses stable sorting, so tied documents keep a deterministic order
   - Averages metrics across tied positions
3. **Missing Relevance Labels**
   - Binary relevance: unlabeled documents are treated as non-relevant
   - Graded relevance: requires explicit labels
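The tie and missing-label rules above can be sketched as follows. Python's `sorted` is stable, so documents with equal scores keep their input order. `fractional_ranks` shows one common way to average over ties: every tied document gets the mean of the positions its tie group spans. Whether the project's evaluator averages exactly this way is an assumption, and the helper names are hypothetical.

```python
from typing import Dict, List, Sequence


def rank_documents(doc_ids: Sequence[str], scores: Sequence[float]) -> List[str]:
    """Sort by descending score; the stable sort keeps tied docs in input order."""
    order = sorted(range(len(doc_ids)), key=lambda i: -scores[i])
    return [doc_ids[i] for i in order]


def fractional_ranks(scores: Sequence[float]) -> List[float]:
    """1-based ranks in which tied scores share the average of their positions."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0.0] * len(scores)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and scores[order[end + 1]] == scores[order[start]]:
            end += 1
        average_position = (start + end) / 2 + 1  # mean of positions start+1 .. end+1
        for j in range(start, end + 1):
            ranks[order[j]] = average_position
        start = end + 1
    return ranks


def binary_labels(doc_ids: Sequence[str], labels: Dict[str, int]) -> List[int]:
    """Binary relevance: documents missing from `labels` count as non-relevant."""
    return [1 if labels.get(doc, 0) > 0 else 0 for doc in doc_ids]
```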
#### Efficiency Optimizations
```python
# Batch evaluation for large datasets
evaluator = Evaluator(
model=model,
batch_size=128,
use_fp16=True,
device="cuda"
)
# Streaming evaluation for memory efficiency
results = evaluator.evaluate_streaming(
data_path="large_dataset.jsonl",
chunk_size=10000
)
```
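Streaming evaluation keeps only one chunk of records in memory at a time. Below is a minimal standard-library sketch of chunked JSONL reading, assuming one JSON object per line. `iter_jsonl_chunks` is a hypothetical helper and does not describe the internals of `evaluate_streaming`.

```python
import json
from itertools import islice
from typing import Iterator, List


def iter_jsonl_chunks(path: str, chunk_size: int = 10000) -> Iterator[List[dict]]:
    """Lazily yield lists of at most `chunk_size` parsed records."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    with open(path, encoding="utf-8") as f:
        records = (json.loads(line) for line in f if line.strip())  # skip blank lines
        while True:
            chunk = list(islice(records, chunk_size))
            if not chunk:
                return
            yield chunk
```

Per-query scores from each chunk can be folded into a running sum and count, so computing the final mean never requires the whole dataset in memory.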
### Performance Benchmarking
```bash
# Benchmark inference speed
python scripts/benchmark.py \
--model_type retriever \
--model_path ./output/bge-m3-finetuned \
--data_path data/test.jsonl \
--batch_sizes 1 8 16 32 64 \
--num_warmup_steps 10 \
--num_benchmark_steps 100
# Output includes:
# - Throughput (queries/second)
# - Latency percentiles (p50, p90, p95, p99)
# - Memory usage
# - Hardware utilization
```
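Throughput and latency percentiles like those `benchmark.py` reports can be measured with a small timing harness. The sketch below is generic (standard library only) and assumes `fn` runs one batch of inference. It is not the script's implementation. On a GPU, `fn` should synchronize the device (for example with `torch.cuda.synchronize()`) before returning; otherwise asynchronous kernel launches make latencies look artificially low.

```python
import statistics
import time
from typing import Callable, Dict, Sequence


def benchmark(fn: Callable[[Sequence[str]], object], batch: Sequence[str],
              num_warmup_steps: int = 10, num_benchmark_steps: int = 100) -> Dict[str, float]:
    """Time fn(batch) repeatedly; return throughput and latency percentiles in ms."""
    if num_benchmark_steps < 2:
        raise ValueError("need at least two timed steps to compute percentiles")
    for _ in range(num_warmup_steps):
        fn(batch)  # warm-up runs absorb one-off costs such as lazy initialization
    latencies = []
    for _ in range(num_benchmark_steps):
        start = time.perf_counter()
        fn(batch)
        latencies.append(time.perf_counter() - start)
    # 99 cut points (1st..99th percentile); "inclusive" interpolates within observed values.
    cuts = statistics.quantiles(latencies, n=100, method="inclusive")
    total = sum(latencies)
    return {
        "throughput_qps": len(batch) * num_benchmark_steps / total if total > 0 else float("inf"),
        "p50_ms": cuts[49] * 1e3,
        "p90_ms": cuts[89] * 1e3,
        "p95_ms": cuts[94] * 1e3,
        "p99_ms": cuts[98] * 1e3,
    }
```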
### Custom Metrics
```python
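from evaluation.evaluator import Evaluator
from evaluation.metrics import BaseMetric


class CustomMetric(BaseMetric):
    def compute(self, predictions, references, k=10):
        # Your custom metric logic goes here. Example: hit rate@k, assuming
        # `predictions` holds one ranked list of doc IDs per query and
        # `references` one collection of relevant doc IDs per query.
        hits = [bool(set(pred[:k]) & set(ref)) for pred, ref in zip(predictions, references)]
        score = sum(hits) / len(hits) if hits else 0.0
        return score


# Register the custom metric (`model` is your loaded retriever)
evaluator = Evaluator(model=model)
evaluator.add_metric("custom_metric", CustomMetric())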
```
### Visualization and Reporting
```python
# Generate evaluation report
from evaluation.reporting import EvaluationReport
report = EvaluationReport(results)
report.generate_html("evaluation_report.html")
report.generate_latex("evaluation_report.tex")
report.plot_metrics("plots/")
# Example plots:
# - Recall@K curves
# - Precision-Recall curves
# - NDCG@K comparison
# - Query-level performance distribution
```
### Best Practices
1. **Use Appropriate K Values**
- Small K (1, 5, 10) for user-facing metrics
- Large K (50, 100) for reranking scenarios
2. **Consider Query Types**
- Navigational queries: Focus on MRR, Success@1
- Informational queries: Focus on NDCG, Recall@K
3. **Statistical Rigor**
- Always use test sets with 1000+ queries
- Report confidence intervals
- Use paired tests for model comparison
4. **Graded Relevance**
- Use NDCG for nuanced evaluation
- Define clear relevance guidelines
- Consider query-specific relevance scales
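For the paired tests recommended above, a sign-flip permutation test is a simple option that makes few distributional assumptions: under the null hypothesis, the sign of each per-query difference is arbitrary. The sketch below is a generic illustration with a hypothetical function name, assuming aligned per-query scores for the two systems.

```python
import random
import statistics
from typing import Sequence


def paired_permutation_test(a: Sequence[float], b: Sequence[float],
                            num_permutations: int = 10000, seed: int = 0) -> float:
    """Two-sided p-value for H0: systems a and b have equal mean per-query scores."""
    if len(a) != len(b) or not a:
        raise ValueError("need two non-empty, equally long lists of per-query scores")
    diffs = [x - y for x, y in zip(a, b)]
    observed = abs(statistics.fmean(diffs))
    rng = random.Random(seed)
    extreme = 0
    for _ in range(num_permutations):
        # Under H0 each difference is equally likely to have either sign.
        flipped = [d if rng.random() < 0.5 else -d for d in diffs]
        if abs(statistics.fmean(flipped)) >= observed - 1e-12:  # tolerance for float noise
            extreme += 1
    # Counting the observed assignment itself keeps the p-value strictly positive.
    return (extreme + 1) / (num_permutations + 1)
```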
---
## 📈 Model Evaluation & Benchmarking
### 1. Detailed Accuracy Evaluation
Use `scripts/evaluate.py` for comprehensive accuracy evaluation of retriever, reranker, or the full pipeline. Supports all standard metrics (MRR, Recall@k, MAP, NDCG, etc.) and baseline comparison.
**Example:**
```bash
python scripts/evaluate.py \
    --retriever_model path/to/finetuned \
    --eval_data path/to/data.jsonl \
    --eval_retriever \
    --k_values 1 5 10

python scripts/evaluate.py \
    --reranker_model path/to/finetuned \
    --eval_data path/to/data.jsonl \
    --eval_reranker

python scripts/evaluate.py \
    --retriever_model path/to/finetuned \
    --reranker_model path/to/finetuned \
    --eval_data path/to/data.jsonl \
    --eval_pipeline \
    --compare_baseline \
    --base_retriever path/to/baseline
```
### 2. Quick Performance Benchmarking
Use `scripts/benchmark.py` for fast, single-model performance profiling (throughput, latency, memory). This script does **not** report accuracy metrics.
**Example:**
```bash
python scripts/benchmark.py --model_type retriever --model_path path/to/model --data_path path/to/data.jsonl --batch_size 16 --device cuda
python scripts/benchmark.py --model_type reranker --model_path path/to/model --data_path path/to/data.jsonl --batch_size 16 --device cuda
```
### 3. Fine-tuned vs. Baseline Comparison (Accuracy & Performance)
Use `scripts/compare_retriever.py` and `scripts/compare_reranker.py` to compare a fine-tuned model against a baseline, reporting both accuracy and performance metrics side by side, including the delta (improvement) for each metric.
**Example:**
```bash
python scripts/compare_retriever.py \
    --finetuned_model_path path/to/finetuned \
    --baseline_model_path path/to/baseline \
    --data_path path/to/data.jsonl \
    --batch_size 16 \
    --device cuda

python scripts/compare_reranker.py \
    --finetuned_model_path path/to/finetuned \
    --baseline_model_path path/to/baseline \
    --data_path path/to/data.jsonl \
    --batch_size 16 \
    --device cuda
```
- Results are saved to a file and include a table of metrics for both models and the improvement (delta).
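A comparison table with deltas, similar to what these scripts save, can be produced with a short formatting helper. The helper below is hypothetical and assumes both models' metrics are available as name-to-value dictionaries; it is not the scripts' actual output code. For latency-style metrics a negative delta is the improvement, so read the sign accordingly.

```python
from typing import Dict, List


def comparison_rows(baseline: Dict[str, float], finetuned: Dict[str, float]) -> List[str]:
    """Markdown rows: metric | baseline | fine-tuned | absolute delta | relative delta."""
    rows = ["| Metric | Baseline | Fine-tuned | Δ | Δ % |", "|---|---|---|---|---|"]
    for name in sorted(baseline.keys() & finetuned.keys()):
        b, f = baseline[name], finetuned[name]
        relative = f"{(f - b) / b:+.1%}" if b else "n/a"  # undefined when the baseline is 0
        rows.append(f"| {name} | {b:.4f} | {f:.4f} | {f - b:+.4f} | {relative} |")
    return rows
```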
---
## 🖥️ Ascend NPU Section
- **Install**: `pip install torch_npu` (see [Ascend docs](https://gitee.com/ascend/pytorch))
- **Edit `config.toml`:**
```toml
[ascend]
device_type = "npu"
backend = "hccl"
mixed_precision = true
visible_devices = "0,1,2,3"
```
- **Launch training:**
```bash
python3.10 scripts/train_m3.py --train_data data/train.jsonl --output_dir ./output/npu
```
- **Distributed:** Use Ascend's distributed launcher or set `backend = "hccl"` in `[ascend]`.
- **All device/backend selection is config-driven.**
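Config-driven device selection can be pictured as a small merge over the TOML sections. The sketch below is an assumption about how such a merge might work, not the project's `config_loader.py`. It starts from hypothetical defaults, applies `[hardware]`, and lets an `[ascend]` section that sets `device_type = "npu"` take over. It uses the standard-library `tomllib` (Python 3.11+); on Python 3.10 the `tomli` package provides the same API.

```python
import sys
from typing import Any, Dict, List

if sys.version_info >= (3, 11):
    import tomllib
else:  # Python 3.10: pip install tomli
    import tomli as tomllib

# Hypothetical fallbacks used when the config file says nothing.
DEFAULTS: Dict[str, Any] = {"device_type": "cuda", "backend": "nccl", "mixed_precision": False}


def parse_visible_devices(value: str) -> List[int]:
    """Turn a string such as "0,1,2,3" into [0, 1, 2, 3]."""
    return [int(part) for part in value.split(",") if part.strip()]


def resolve_device_settings(config_text: str) -> Dict[str, Any]:
    """Merge defaults <- [hardware] <- [ascend] (the latter only when it selects an NPU)."""
    config = tomllib.loads(config_text)
    settings = dict(DEFAULTS)
    settings.update(config.get("hardware", {}))
    ascend = config.get("ascend", {})
    if ascend.get("device_type") == "npu":
        settings.update(ascend)
    if isinstance(settings.get("visible_devices"), str):
        settings["visible_devices"] = parse_visible_devices(settings["visible_devices"])
    return settings
```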
---
## 🚀 Performance Optimization
### Multi-GPU Training
```bash
torchrun --nproc_per_node=2 scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir ./output/distributed \
--fp16 \
--gradient_checkpointing
```
### Memory Optimization
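- **Gradient Checkpointing**: Cuts activation memory by recomputing activations during the backward pass, at the cost of extra compute
- **Mixed Precision**: FP16/BF16 training (e.g. on RTX 5090)
- **Gradient Accumulation**: Simulates a larger effective batch size by accumulating gradients over several micro-batches before each optimizer step
- **DeepSpeed Integration**: For large-scale training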
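The batch-size arithmetic behind gradient accumulation is simple enough to sketch. Assuming plain data parallelism, one optimizer step sees `per_device_batch * grad_accum_steps * world_size` samples. The helpers below are hypothetical. For contrastive losses with in-batch negatives, accumulation alone does not add negatives, because each micro-batch's loss only sees that micro-batch (unless the trainer gathers or caches embeddings across micro-batches or devices).

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, world_size: int = 1) -> int:
    """Samples contributing to a single optimizer step."""
    if min(per_device_batch, grad_accum_steps, world_size) < 1:
        raise ValueError("all arguments must be >= 1")
    return per_device_batch * grad_accum_steps * world_size


def accumulation_steps_for(target_batch: int, per_device_batch: int, world_size: int = 1) -> int:
    """Smallest accumulation count whose effective batch is at least target_batch."""
    samples_per_micro_step = per_device_batch * world_size
    if samples_per_micro_step < 1 or target_batch < 1:
        raise ValueError("batch sizes and world_size must be >= 1")
    return -(-target_batch // samples_per_micro_step)  # ceiling division
```

For example, 2 GPUs with a per-device batch of 16 reach an effective batch of 128 with 4 accumulation steps.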
### Huawei Ascend Compatibility
```python
# For Ascend deployment
import torch
import torch_npu  # registers the "npu" device type with PyTorch

device = torch.device("npu")
model.to(device)  # `model` is your loaded BGE model
```
---
## 📝 **Recent Changes & Improvements**
### ✅ **Simplified Data Pipeline (Latest)**
- **Removed unnecessary PKL conversion pipeline**: The `data/optimization.py` module has been removed
- Performance benefits came from pre-tokenization, not PKL format
- PKL files caused memory issues and added complexity without real benefits
- **Added in-memory caching option**: For small datasets (< 100k samples), enable `cache_in_memory=True`
- **Kept JSONL as primary format**: Industry standard, human-readable, and efficient
- **All core functionality intact**: Training, evaluation, and model features unchanged
### 🚀 **Performance Tips**
For better training performance:
```python
# Enable in-memory caching for small datasets
dataset = BGEM3Dataset(
data_path="train.jsonl",
tokenizer=tokenizer,
cache_in_memory=True, # Pre-tokenize and cache in memory
max_cache_size=100000 # Maximum samples to cache
)
```
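Conceptually, `cache_in_memory` trades RAM for skipping repeated tokenization. Below is a hypothetical, framework-free sketch of that idea; it is not the real `BGEM3Dataset`. The first `max_cache_size` samples are tokenized once up front, and any remaining samples are tokenized on access. `tokenize` stands in for a real tokenizer, and the `"query"` field name is an assumption.

```python
import json
from typing import Callable, Dict, List


class CachedJsonlDataset:
    """Pre-tokenizes up to `max_cache_size` records; tokenizes the rest lazily."""

    def __init__(self, data_path: str, tokenize: Callable[[str], List[int]],
                 cache_in_memory: bool = True, max_cache_size: int = 100_000):
        with open(data_path, encoding="utf-8") as f:
            self.records = [json.loads(line) for line in f if line.strip()]
        self.tokenize = tokenize
        self.cache: Dict[int, List[int]] = {}
        if cache_in_memory:
            for idx, record in enumerate(self.records[:max_cache_size]):
                self.cache[idx] = tokenize(record["query"])

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> List[int]:
        cached = self.cache.get(idx)
        if cached is not None:
            return cached
        return self.tokenize(self.records[idx]["query"])
```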
---
## 🏗 Architecture
```
bge_finetune/
├── config.toml # Main configuration file
├── requirements.txt # Python dependencies
├── README.md # Project documentation
├── __init__.py # Package initialization
├── config/ # Configuration management
│ ├── __init__.py
│ ├── base_config.py # Base configuration class with validation
│ ├── m3_config.py # BGE-M3 specific configuration
│ └── reranker_config.py # BGE-Reranker specific configuration
├── data/ # Data processing and management
│ ├── __init__.py
│ ├── dataset.py # Dataset classes (Contrastive, Reranker, Joint)
│ ├── preprocessing.py # Format conversion and validation
│ ├── augmentation.py # Query/passage augmentation strategies
│   └── hard_negative_mining.py # ANCE-style hard negative mining
├── models/ # Model implementations
│ ├── __init__.py
│ ├── bge_m3.py # BGE-M3 wrapper with multi-functionality
│ ├── bge_reranker.py # BGE-Reranker cross-encoder wrapper
│ └── losses.py # Loss functions (InfoNCE, ListwiseCE, etc.)
├── training/ # Training logic
│ ├── __init__.py
│ ├── trainer.py # Base trainer with error recovery
│ ├── m3_trainer.py # BGE-M3 specialized trainer
│ ├── reranker_trainer.py # Reranker specialized trainer
│ └── joint_trainer.py # RocketQAv2-style joint training
├── evaluation/ # Evaluation and metrics
│ ├── __init__.py
│ ├── evaluator.py # Evaluation pipeline
│ ├── metrics.py # Metric implementations (MRR, NDCG, etc.)
│ └── reporting.py # Result visualization and reporting
├── utils/ # Utilities and helpers
│ ├── __init__.py
│ ├── logging.py # Distributed-aware logging with tqdm
│ ├── checkpoint.py # Atomic checkpoint save/load
│ ├── distributed.py # Distributed training utilities
│ ├── config_loader.py # TOML config loading with precedence
│ ├── memory_utils.py # Memory optimization utilities
│ └── ascend_support.py # Huawei Ascend NPU support
├── scripts/ # Entry point scripts
│ ├── __init__.py
│ ├── train_m3.py # Train BGE-M3 embeddings
│ ├── train_reranker.py # Train BGE-Reranker
│ ├── train_joint.py # Joint training (RocketQAv2)
│ ├── evaluate.py # Comprehensive evaluation
│ ├── benchmark.py # Performance benchmarking
│ ├── compare_retriever.py # Retriever comparison
│ ├── compare_reranker.py # Reranker comparison
│ ├── test_installation.py # Environment verification
│ └── setup.py # Initial setup script
├── cache/ # Cache directory (auto-created)
│ ├── models/ # Downloaded model cache
│ └── data/ # Preprocessed data cache
├── logs/ # Logging directory (auto-created)
│ ├── training/ # Training logs
│ ├── evaluation/ # Evaluation results
│ └── errors.log # Error logs
├── output/ # Training outputs (auto-created)
│ ├── bge-m3-finetuned/ # Fine-tuned models
│ ├── bge-reranker-finetuned/ # Fine-tuned rerankers
│ └── checkpoints/ # Training checkpoints
└── docs/ # Additional documentation
├── API.md # API reference
├── CONTRIBUTING.md # Contribution guidelines
└── USAGE_GUIDE.md # Detailed usage guide
```
### Module Descriptions
#### Core Modules
- **`config/`**: Hierarchical configuration system with TOML support and validation
- Automatic type checking and value validation
- Clear precedence rules (CLI > Model > Hardware > Training > Defaults)
- Dynamic configuration based on hardware detection
- **`data/`**: Comprehensive data handling
- Format detection and conversion (JSONL, JSON, CSV, TSV, XML, DuReader, MS MARCO)
- Streaming support for large datasets
- Built-in augmentation with LLM and local model support
- ANCE-style hard negative mining with FAISS integration
- **`models/`**: Model wrappers and loss functions
- BGE-M3: Dense, sparse (SPLADE), and multi-vector (ColBERT) support
- BGE-Reranker: Cross-encoder with knowledge distillation
- Loss functions: InfoNCE (τ=0.02), ListwiseCE, RankNet, LambdaRank
- **`training/`**: Robust training infrastructure
- Error recovery with configurable failure thresholds
- Automatic mixed precision (fp16/bf16) based on hardware
- Gradient checkpointing and memory optimization
- Distributed training with automatic backend selection
- **`evaluation/`**: Comprehensive evaluation suite
- Metrics: Recall@K, MRR@K, NDCG@K, MAP, Success Rate, Coverage
- Graded relevance support
- Statistical significance testing
- Performance benchmarking
- **`utils/`**: Production-ready utilities
- TqdmLoggingHandler for progress tracking
- Atomic checkpoint saving with backup
- Automatic distributed backend detection (NCCL/HCCL/Gloo)
- Full Ascend NPU support module
#### Entry Points
- **Training Scripts**:
- `train_m3.py`: BGE-M3 training with all features
- `train_reranker.py`: Reranker training with distillation
- `train_joint.py`: RocketQAv2-style joint training
- **Evaluation Scripts**:
- `evaluate.py`: Full evaluation with all metrics
- `benchmark.py`: Performance profiling
- `compare_*.py`: Side-by-side model comparison
- **Utility Scripts**:
- `test_installation.py`: Verify environment setup
- `setup.py`: Download models and prepare environment
### Key Features by Module
1. **Automatic Hardware Detection**
- CUDA capability checking for optimal dtype (fp16/bf16)
- Ascend NPU detection and configuration
- CPU fallback for development
2. **Robust Error Handling**
- Training continues despite batch failures
- Automatic checkpoint recovery
- Comprehensive error logging
3. **Memory Efficiency**
- Gradient checkpointing
- Periodic cache clearing
- Streaming data loading
- Mixed precision training
4. **Production Features**
- Atomic checkpoint saves
- Distributed training support
- Comprehensive logging
- Configuration validation
---
## 🔧 Troubleshooting & FAQ
**Q: My GPU is not detected or torch reports CUDA version mismatch.**
- Ensure you installed torch >=2.7.1 with CUDA 12.8+ support (see installation section).
- Run `nvidia-smi` to check your driver and CUDA version.
**Q: How do I enable Ascend NPU support?**
- Install `torch_npu` and set `device_type = "npu"` in `[hardware]` or `[ascend]` in `config.toml`.
- See [Ascend documentation](https://gitee.com/ascend/pytorch) for details.
**Q: How do config parameters interact?**
- Command-line arguments override every config section.
- `[m3]`/`[reranker]` override `[training]`/`[hardware]` for their respective models.
- `[hardware]` overrides `[training]` for overlapping parameters.
- See the config class docstrings and `config.toml` for details.
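The precedence order (CLI > model-specific > hardware > training > defaults) maps naturally onto `collections.ChainMap`, which returns a key's value from the first mapping that contains it. This is an illustrative sketch with hypothetical section contents and function name, not the project's `config_loader.py`.

```python
from collections import ChainMap
from typing import Any, Dict, Mapping, Optional


def resolve_config(defaults: Mapping[str, Any], training: Mapping[str, Any],
                   hardware: Mapping[str, Any], model_specific: Mapping[str, Any],
                   cli: Mapping[str, Optional[Any]]) -> Dict[str, Any]:
    """Resolve each key as CLI > [m3]/[reranker] > [hardware] > [training] > defaults."""
    # Flags the user did not pass (None) must not shadow lower-priority values.
    cli_set = {key: value for key, value in cli.items() if value is not None}
    # ChainMap looks keys up in order and returns the first match.
    return dict(ChainMap(cli_set, model_specific, hardware, training, defaults))
```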
**Q: How do I run tests?**
- Install `pytest` and run `pytest tests/` (if tests are present).
- Or run `python scripts/test_installation.py` for a full environment check.
---
## 📚 References
1. **BGE M3-Embedding**: [arXiv:2402.03216](https://arxiv.org/abs/2402.03216)
2. **RocketQAv2**: [EMNLP 2021](https://aclanthology.org/2021.emnlp-main.224/)
3. **ANCE**: [arXiv:2007.00808](https://arxiv.org/abs/2007.00808)
4. **ColBERT**: [SIGIR 2020](https://dl.acm.org/doi/10.1145/3397271.3401075)
5. **SPLADE**: [arXiv:2109.10086](https://arxiv.org/abs/2109.10086)
---
## 🔮 Roadmap
- [ ] Knowledge Graph integration
- [ ] Multi-modal retrieval support
- [ ] Additional language support
- [ ] Distributed training optimization
- [ ] Model quantization for deployment
- [ ] Benchmarking on standard datasets