# BGE Fine-tuning Usage Guide

## Overview

This guide provides detailed instructions on fine-tuning BGE (BAAI General Embedding) models using our enhanced training framework. The framework supports both BGE-M3 (embedding) and BGE-reranker (cross-encoder) models with state-of-the-art training techniques.

## Table of Contents

1. [Installation](#installation)
2. [Quick Start](#quick-start)
3. [Data Preparation](#data-preparation)
4. [Training Scripts](#training-scripts)
5. [Advanced Usage](#advanced-usage)
6. [Troubleshooting](#troubleshooting)

## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/bge-finetune.git
cd bge-finetune

# Install dependencies
pip install -r requirements.txt

# Optional: Install Ascend NPU support
# pip install torch-npu  # If using Huawei Ascend NPUs
```

## Quick Start

### 1. Prepare Your Data

The framework supports multiple data formats:

- **JSONL** (recommended): One JSON object per line
- **CSV/TSV**: Tabular data with headers
- **JSON**: Single JSON file with an array of samples

Example JSONL format for the embedding model:

```json
{"query": "What is machine learning?", "pos": ["ML is a subset of AI..."], "neg": ["The weather today..."]}
{"query": "How to cook pasta?", "pos": ["Boil water and add pasta..."], "neg": ["Machine learning uses..."]}
```
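
If you generate training data programmatically, writing one JSON object per line is all the JSONL format requires; a minimal sketch (the file path is just an example):

```python
import json

samples = [
    {"query": "What is machine learning?",
     "pos": ["ML is a subset of AI..."],
     "neg": ["The weather today..."]},
]

# One JSON object per line, no enclosing array.
with open("data/train.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```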

### 2. Fine-tune BGE-M3 (Embedding Model)

```bash
python scripts/train_m3.py \
    --model_name_or_path BAAI/bge-m3 \
    --train_data data/train.jsonl \
    --output_dir ./output/bge-m3-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --learning_rate 1e-5
```
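
After training finishes, you can sanity-check the saved model, for example with the `FlagEmbedding` package (an assumption here: it is installed separately, e.g. `pip install FlagEmbedding`):

```python
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel("./output/bge-m3-finetuned", use_fp16=True)
sentences = ["What is machine learning?", "ML is a subset of AI..."]

# encode() returns a dict; "dense_vecs" holds one embedding per sentence.
embeddings = model.encode(sentences)["dense_vecs"]
# The dense vectors are unit-normalized, so the dot product is a cosine similarity.
print(embeddings[0] @ embeddings[1])
```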

### 3. Fine-tune BGE-Reranker

```bash
python scripts/train_reranker.py \
    --model_name_or_path BAAI/bge-reranker-base \
    --train_data data/train_pairs.jsonl \
    --output_dir ./output/bge-reranker-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5
```
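
To sanity-check the fine-tuned reranker, you can score query-passage pairs directly with Hugging Face `transformers` (a generic cross-encoder usage sketch):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./output/bge-reranker-finetuned")
model = AutoModelForSequenceClassification.from_pretrained("./output/bge-reranker-finetuned")
model.eval()

pairs = [
    ["What is machine learning?", "ML is a subset of AI..."],
    ["What is machine learning?", "The weather today..."],
]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True,
                       max_length=512, return_tensors="pt")
    scores = model(**inputs).logits.view(-1).float()
print(scores)  # higher score = more relevant
```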

## Data Preparation

### Data Formats

#### 1. Embedding Model (Triplets Format)

```json
{
  "query": "query text",
  "pos": ["positive passage 1", "positive passage 2"],
  "neg": ["negative passage 1", "negative passage 2"],
  "pos_scores": [1.0, 0.9],
  "neg_scores": [0.1, 0.2]
}
```

`pos_scores` and `neg_scores` are optional relevance scores.

#### 2. Reranker Model (Pairs Format)

```json
{
  "query": "query text",
  "passage": "passage text",
  "label": 1,
  "score": 0.95
}
```

`label` is 1 for relevant and 0 for not relevant; `score` is an optional relevance score.
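
If your data is already in the triplets format, it can be flattened into reranker pairs with a few lines of Python (a sketch; the file names are examples):

```python
import json

with open("data/train.jsonl", encoding="utf-8") as fin, \
     open("data/train_pairs.jsonl", "w", encoding="utf-8") as fout:
    for line in fin:
        ex = json.loads(line)
        # Each positive passage becomes a label-1 pair, each negative a label-0 pair.
        for passage in ex["pos"]:
            fout.write(json.dumps({"query": ex["query"], "passage": passage, "label": 1}) + "\n")
        for passage in ex["neg"]:
            fout.write(json.dumps({"query": ex["query"], "passage": passage, "label": 0}) + "\n")
```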

### Performance Optimization

For small datasets (< 100k samples), enable in-memory caching:

```python
# In your custom script
dataset = BGEM3Dataset(
    data_path="train.jsonl",
    tokenizer=tokenizer,
    cache_in_memory=True,  # Pre-tokenize and cache in memory
    max_cache_size=100000  # Maximum samples to cache
)
```

## Training Scripts

### BGE-M3 Training Options

```bash
python scripts/train_m3.py \
    --model_name_or_path BAAI/bge-m3 \
    --train_data data/train.jsonl \
    --eval_data data/eval.jsonl \
    --output_dir ./output/bge-m3-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 32 \
    --learning_rate 1e-5 \
    --warmup_ratio 0.1 \
    --gradient_accumulation_steps 2 \
    --fp16 \
    --train_group_size 8 \
    --use_hard_negatives \
    --temperature 0.02 \
    --save_steps 500 \
    --eval_steps 500 \
    --logging_steps 100
```

`--fp16` enables mixed precision training, `--train_group_size` sets the number of passages per query, `--use_hard_negatives` enables hard negative mining, and `--temperature` sets the contrastive loss temperature.
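
With the settings above on a single GPU, the effective batch size is `per_device_train_batch_size` × `gradient_accumulation_steps` = 16 × 2 = 32. For intuition about `--temperature`, the sketch below shows a temperature-scaled contrastive (InfoNCE) objective with in-batch negatives; it illustrates the technique, not the framework's exact loss implementation:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(q, p, temperature=0.02):
    """Temperature-scaled contrastive loss with in-batch negatives.

    q: (B, D) query embeddings; p: (B, D) embeddings of each query's
    positive passage. Every other passage in the batch serves as a negative.
    """
    q = F.normalize(q, dim=-1)
    p = F.normalize(p, dim=-1)
    logits = q @ p.T / temperature  # (B, B) similarity matrix
    targets = torch.arange(q.size(0), device=logits.device)  # diagonal = positives
    return F.cross_entropy(logits, targets)
```

Lower temperatures sharpen the softmax, penalizing hard negatives more aggressively.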

### BGE-Reranker Training Options

```bash
python scripts/train_reranker.py \
    --model_name_or_path BAAI/bge-reranker-base \
    --train_data data/train_pairs.jsonl \
    --eval_data data/eval_pairs.jsonl \
    --output_dir ./output/bge-reranker-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5 \
    --max_length 512 \
    --train_group_size 16 \
    --gradient_checkpointing \
    --logging_steps 50
```

Here `--train_group_size` sets the number of pairs per query, and `--gradient_checkpointing` trades extra compute for lower memory use.

### Joint Training (RocketQAv2 Approach)

Train the retriever and reranker together:

```bash
python scripts/train_joint.py \
    --retriever_model BAAI/bge-m3 \
    --reranker_model BAAI/bge-reranker-base \
    --train_data data/train.jsonl \
    --output_dir ./output/joint-training \
    --num_train_epochs 3 \
    --alternate_steps 100  # Switch between models every N steps
```
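
Conceptually, alternating joint training updates one model while the other is held fixed, switching roles every `--alternate_steps` steps. The sketch below illustrates only the alternation schedule, using hypothetical stand-in models and placeholder losses; `scripts/train_joint.py` may organize this differently:

```python
import torch
from torch import nn

# Hypothetical stand-ins for the retriever and reranker.
retriever = nn.Linear(8, 8)
reranker = nn.Linear(8, 1)
opt_retriever = torch.optim.AdamW(retriever.parameters(), lr=1e-5)
opt_reranker = torch.optim.AdamW(reranker.parameters(), lr=2e-5)
alternate_steps = 100

for step in range(1000):
    batch = torch.randn(4, 8)  # placeholder batch
    if (step // alternate_steps) % 2 == 0:
        # Retriever phase: only the retriever is updated.
        loss = retriever(batch).pow(2).mean()  # placeholder loss
        opt_retriever.zero_grad()
        loss.backward()
        opt_retriever.step()
    else:
        # Reranker phase: only the reranker is updated.
        loss = reranker(batch).pow(2).mean()  # placeholder loss
        opt_reranker.zero_grad()
        loss.backward()
        opt_reranker.step()
```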

## Advanced Usage

### Configuration Files

Use TOML configuration for complex setups:

```toml
# config.toml
[training]
default_batch_size = 16
default_num_epochs = 3
gradient_accumulation_steps = 2

[m3]
model_name_or_path = "BAAI/bge-m3"
train_group_size = 8
temperature = 0.02

[hardware]
device_type = "cuda"  # or "npu" for Ascend
```

Run with the config:

```bash
python scripts/train_m3.py --config_path config.toml --train_data data/train.jsonl
```
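
If you need to read the same TOML file from your own tooling, Python 3.11+ ships `tomllib` in the standard library; a minimal sketch (key names follow the example config above):

```python
import tomllib  # standard library in Python 3.11+

# TOML files must be opened in binary mode for tomllib.
with open("config.toml", "rb") as f:
    config = tomllib.load(f)

batch_size = config["training"]["default_batch_size"]  # 16
temperature = config["m3"]["temperature"]              # 0.02
device_type = config["hardware"]["device_type"]        # "cuda"
```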

### Multi-GPU Training

```bash
# Single node, multiple GPUs
torchrun --nproc_per_node=4 scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/bge-m3-multigpu

# Multiple nodes
torchrun --nnodes=2 --nproc_per_node=4 \
    --rdzv_id=100 --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    scripts/train_m3.py --train_data data/train.jsonl
```
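
Under `torchrun`, each worker process receives its rank and device assignment through environment variables. A generic sketch of the initialization a distributed training script performs (not necessarily this repo's exact code):

```python
import os

import torch
import torch.distributed as dist

# torchrun sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)  # pin this process to its GPU
```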

### Custom Data Preprocessing

```python
from transformers import AutoTokenizer

from data.preprocessing import DataPreprocessor

# Tokenizer for the model you plan to fine-tune
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")

# Preprocess custom format data
preprocessor = DataPreprocessor(
    tokenizer=tokenizer,
    max_query_length=64,
    max_passage_length=512
)

# Convert and clean the data, holding out 10% for validation
output_path, val_path = preprocessor.preprocess_file(
    input_path="raw_data.csv",
    output_path="processed_data.jsonl",
    file_format="csv",
    validation_split=0.1
)
```

### Model Evaluation

```bash
python scripts/evaluate.py \
    --retriever_model ./output/bge-m3-finetuned \
    --reranker_model ./output/bge-reranker-finetuned \
    --eval_data data/test.jsonl \
    --output_dir ./evaluation_results \
    --metrics ndcg mrr recall \
    --top_k 10 20 50
```
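
For reference, two of the reported ranking metrics can be computed as follows (a sketch; `scripts/evaluate.py` may implement them differently):

```python
def mrr_at_k(ranked_ids, relevant_ids, k=10):
    """Reciprocal rank of the first relevant result within the top k."""
    for i, doc_id in enumerate(ranked_ids[:k]):
        if doc_id in relevant_ids:
            return 1.0 / (i + 1)
    return 0.0

def recall_at_k(ranked_ids, relevant_ids, k=10):
    """Fraction of the relevant documents retrieved within the top k."""
    hits = sum(1 for doc_id in ranked_ids[:k] if doc_id in relevant_ids)
    return hits / max(len(relevant_ids), 1)

# Example: the only relevant document "d2" is ranked second.
print(mrr_at_k(["d1", "d2", "d3"], {"d2"}))     # 0.5
print(recall_at_k(["d1", "d2", "d3"], {"d2"}))  # 1.0
```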

## Troubleshooting

### Common Issues

1. **Out of Memory**
   - Reduce the batch size
   - Enable gradient checkpointing: `--gradient_checkpointing`
   - Use fp16 training: `--fp16`
   - Enable in-memory caching for small datasets

2. **Slow Training**
   - Increase the number of data loader workers: `--dataloader_num_workers 8`
   - Enable in-memory caching for datasets < 100k samples
   - Use an SSD for data storage

3. **Poor Performance**
   - Check data quality and format
   - Adjust the learning rate and warmup
   - Use hard negative mining: `--use_hard_negatives` (see the sketch after this list)
   - Increase the number of training epochs
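
Hard negative mining generally means scoring candidate passages with the current model and keeping the top-ranked ones that are not labeled relevant. A generic sketch using `FlagEmbedding` (an assumption: the package is installed; the framework's internal miner may work differently):

```python
import numpy as np
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)

query = "What is machine learning?"
corpus = ["ML is a subset of AI...", "The weather today...", "Boil water and add pasta..."]
positives = {"ML is a subset of AI..."}

q_emb = model.encode([query])["dense_vecs"]  # (1, D)
c_emb = model.encode(corpus)["dense_vecs"]   # (N, D)
scores = (q_emb @ c_emb.T)[0]                # similarity of the query to each passage

# Highest-scoring passages that are not positives become hard negatives.
ranked = np.argsort(-scores)
hard_negatives = [corpus[i] for i in ranked if corpus[i] not in positives][:2]
```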

### Debugging Tips

```bash
# Enable debug logging
export LOG_LEVEL=DEBUG
python scripts/train_m3.py ...

# Profile training
python scripts/benchmark.py \
    --model_type retriever \
    --model_path ./output/bge-m3-finetuned \
    --data_path data/test.jsonl
```

## Best Practices

1. **Data Quality**
   - Ensure balanced positive/negative samples
   - Use relevance scores when available
   - Clean and normalize text data

2. **Training Strategy**
   - Start with a small learning rate (1e-5 to 2e-5)
   - Use warmup (10% of training steps)
   - Monitor evaluation metrics

3. **Resource Optimization**
   - Use gradient accumulation for large effective batch sizes
   - Enable mixed precision training
   - Consider in-memory caching for small datasets

## Additional Resources

- [Data Format Examples](data_formats.md)
- [Model Architecture Details](../models/README.md)
- [Evaluation Metrics Guide](../evaluation/README.md)
- [API Reference](api_reference.md)