docs/data_formats.md (new file)
# BGE Data Pipeline Format Specifications

The refactored BGE data pipeline supports four distinct input formats with automatic detection and conversion:

## 📊 Format Overview

| Format | Use Case | Detection Key | Output |
|--------|----------|---------------|---------|
| **Triplets** | BGE-M3 embedding training | `pos` + `neg` | Contrastive learning batches |
| **Pairs** | BGE-reranker training | `passage` + `label` | Cross-encoder pairs |
| **Candidates** | Unified threshold conversion | `candidates` | Auto-exploded pairs |
| **Legacy Nested** | Backward compatibility | `pos` only | Converted to pairs |
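
The key-based detection rules in the table can be sketched in a few lines (the function name `detect_format` is illustrative, not the pipeline's actual API; since the rules overlap, the real detector may use additional cues):

```python
def detect_format(sample: dict) -> str:
    """Classify one input record by which keys are present."""
    if "candidates" in sample:
        return "candidates"
    if "passage" in sample and "label" in sample:
        return "pairs"
    if "pos" in sample and "neg" in sample:
        return "triplets"
    if "pos" in sample:
        return "legacy_nested"
    raise ValueError(f"unrecognized record keys: {sorted(sample)}")

print(detect_format({"query": "q", "candidates": []}))  # candidates
```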

---

## 1. 🎯 Triplets Format (BGE-M3 Embedding)

For contrastive learning with multiple positives and negatives per query.

```json
{
  "query": "What is machine learning?",
  "pos": [
    "Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
    "ML algorithms use statistical techniques to learn patterns from data and make predictions."
  ],
  "neg": [
    "Weather forecasting uses meteorological data to predict atmospheric conditions.",
    "Cooking recipes require specific ingredients and step-by-step instructions."
  ],
  "pos_scores": [0.95, 0.88],
  "neg_scores": [0.15, 0.08],
  "prompt": "为此查询生成表示:",
  "type": "definition"
}
```

### Required Fields
- `query`: Query text
- `pos`: List of positive passages

### Optional Fields
- `neg`: List of negative passages
- `pos_scores`: Relevance scores for positives
- `neg_scores`: Relevance scores for negatives
- `prompt`: Query instruction prefix
- `type`: Query type classification

---

## 2. 🔍 Pairs Format (BGE-Reranker)

For cross-encoder training with individual query-passage pairs.

```json
{
  "query": "What is machine learning?",
  "passage": "Machine learning is a subset of artificial intelligence that enables computers to learn without explicit programming.",
  "label": 1,
  "score": 0.95,
  "qid": "q1",
  "pid": "p1"
}
```

### Required Fields
- `query`: Query text
- `passage`: Passage text
- `label`: Relevance label (0 or 1)

### Optional Fields
- `score`: Relevance score (0.0-1.0)
- `qid`: Query identifier
- `pid`: Passage identifier

---

## 3. 🎲 Candidates Format (Unified)

For datasets with multiple candidates per query. Each candidate is automatically exploded into a separate pair based on a score threshold.

```json
{
  "query": "What is artificial intelligence?",
  "qid": "q1",
  "candidates": [
    {
      "text": "Artificial intelligence is the simulation of human intelligence in machines.",
      "score": 0.94,
      "pid": "c1"
    },
    {
      "text": "AI systems can perform tasks that typically require human intelligence.",
      "score": 0.87,
      "pid": "c2"
    },
    {
      "text": "Mountain climbing requires proper equipment and training.",
      "score": 0.23,
      "pid": "c3"
    }
  ]
}
```

### Required Fields
- `query`: Query text
- `candidates`: List of candidate objects

### Candidate Object Fields
- `text`: Candidate passage text
- `score`: Relevance score (optional, defaults to 1.0)
- `pid`: Passage identifier (optional, auto-generated)

### Processing Logic
- **Threshold Assignment**: Candidates with `score >= threshold` become `label=1`, others become `label=0`
- **Default Threshold**: 0.5 (configurable)
- **Output**: Multiple pairs, one per candidate
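
A minimal sketch of this explosion step, following the threshold rules above (`explode_candidates` is an illustrative name, not the pipeline's actual API):

```python
def explode_candidates(record: dict, threshold: float = 0.5) -> list:
    """Turn one candidates record into flat (query, passage, label) pairs."""
    pairs = []
    for i, cand in enumerate(record["candidates"]):
        score = cand.get("score", 1.0)        # score is optional, defaults to 1.0
        pairs.append({
            "query": record["query"],
            "passage": cand["text"],
            "label": 1 if score >= threshold else 0,
            "score": score,
            "pid": cand.get("pid", f"c{i}"),  # pid is optional, auto-generated
        })
    return pairs
```

With the example record above and the default threshold of 0.5, the first two candidates become `label=1` pairs and the third becomes `label=0`.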

---

## 4. 🔄 Legacy Nested Format (Backward Compatibility)

A legacy format that is automatically converted to flat pairs for reranker training.

```json
{
  "query": "What is natural language processing?",
  "pos": [
    "Natural language processing is a field of AI that helps computers understand human language.",
    "NLP combines computational linguistics with machine learning to process text and speech."
  ],
  "neg": [
    "Automobile manufacturing involves assembly lines and quality control processes.",
    "Photography captures light through camera lenses to create images."
  ],
  "pos_scores": [0.93, 0.86],
  "neg_scores": [0.14, 0.22]
}
```

### Required Fields
- `query`: Query text
- `pos`: List of positive passages

### Optional Fields
- `neg`: List of negative passages
- `pos_scores`: Scores for positives
- `neg_scores`: Scores for negatives

### Processing Logic
- Each `pos[i]` becomes `(query, passage, label=1)`
- Each `neg[j]` becomes `(query, passage, label=0)`
- Maintains original scores
- Tracks source as `legacy_nested`
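
The conversion steps above can be sketched as follows (`legacy_to_pairs` is an illustrative name, not the pipeline's actual API):

```python
def legacy_to_pairs(record: dict) -> list:
    """Flatten a nested legacy record into labeled pairs, keeping scores."""
    pairs = []
    for key, label in (("pos", 1), ("neg", 0)):
        passages = record.get(key, [])
        scores = record.get(f"{key}_scores", [None] * len(passages))
        for passage, score in zip(passages, scores):
            pairs.append({
                "query": record["query"],
                "passage": passage,
                "label": label,
                "score": score,
                "source": "legacy_nested",  # track the original format
            })
    return pairs
```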

---

## 🔧 Usage Examples

### Fine-tuning Data Loader

```python
from data.dataset import BGEM3Dataset, BGERerankerDataset, BGEDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Triplets for BGE-M3
triplets_dataset = BGEM3Dataset(
    data_path="triplets_data.jsonl",
    tokenizer=tokenizer,
    format_type="triplets"
)

# Pairs for BGE-reranker
pairs_dataset = BGERerankerDataset(
    data_path="pairs_data.jsonl",
    tokenizer=tokenizer,
    format_type="pairs"
)

# Candidates with threshold conversion
candidates_dataset = BGEDataset(
    data_path="candidates_data.jsonl",
    tokenizer=tokenizer,
    format_type="candidates",
    candidates_threshold=0.6  # Custom threshold
)

# Legacy with automatic conversion
legacy_dataset = BGERerankerDataset(
    data_path="legacy_data.jsonl",
    tokenizer=tokenizer,
    legacy_support=True
)
```

### Optimization Pipeline

```bash
# BGE-M3 optimization (triplets → cached triplets)
python -m data.optimization bge-m3 triplets_data.jsonl \
    --output_dir ./cache/data \
    --hard_negative_ratio 0.8 \
    --difficulty_threshold 0.2

# BGE-reranker optimization (pairs → cached pairs)
python -m data.optimization bge-reranker pairs_data.jsonl \
    --output_dir ./cache/data \
    --output_format flat

# Candidates optimization (candidates → cached pairs)
python -m data.optimization bge-reranker candidates_data.jsonl \
    --output_dir ./cache/data

# Legacy optimization (nested → cached pairs)
python -m data.optimization bge-reranker legacy_data.jsonl \
    --output_dir ./cache/data \
    --output_format nested
```

---

## ⚡ Performance Features

### Enhanced Processing
- **10-15x faster loading** through pre-tokenization
- **Automatic format detection** based on key presence
- **Score-weighted sampling** for better training quality
- **Hard negative mining** with difficulty thresholds

### Conversion Utilities
- **Candidates → Pairs**: Threshold-based label assignment
- **Legacy → Pairs**: Automatic nested-to-flat conversion
- **Pairs → Triplets**: Grouping by query for embedding training
- **Format validation**: Comprehensive error checking

### Backward Compatibility
- **Legacy support**: Full compatibility with existing datasets
- **Gradual migration**: Convert formats incrementally
- **Source tracking**: Know the original format of converted data
- **Flexible thresholds**: Configurable conversion parameters

---

## 🚀 Key Benefits

1. **Unified Pipeline**: Single codebase handles all formats
2. **Automatic Detection**: No manual format specification needed
3. **Seamless Conversion**: Transparent format transformations
4. **Performance Optimized**: Significant speedup through caching
5. **Backward Compatible**: Works with existing datasets
6. **Flexible Configuration**: Customizable thresholds and parameters

docs/directory_structure.md (new file)
# BGE Fine-tuning Directory Structure

This document clarifies the directory structure and cache usage in the BGE fine-tuning project.

## 📂 Project Directory Structure

```
bge_finetune/
├── 📁 cache/                     # Cache directories
│   ├── 📁 data/                  # Preprocessed dataset cache
│   │   ├── 📄 *.jsonl            # Preprocessed JSONL files
│   │   └── 📄 *.json             # Preprocessed JSON files
│   └── 📁 models/                # Downloaded model cache (HuggingFace/ModelScope)
│       ├── 📁 tokenizers/        # Cached tokenizer files
│       ├── 📁 config/            # Cached model config files
│       └── 📁 downloads/         # Temporary download files
│
├── 📁 models/                    # Local model storage
│   ├── 📁 bge-m3/                # Local BGE-M3 model files
│   └── 📁 bge-reranker-base/     # Local BGE-reranker model files
│
├── 📁 data/                      # Raw training data
│   ├── 📄 train.jsonl            # Input training data
│   ├── 📄 eval.jsonl             # Input evaluation data
│   └── 📁 processed/             # Preprocessed (but not cached) data
│
├── 📁 output/                    # Training outputs
│   ├── 📁 bge-m3/                # M3 training results
│   ├── 📁 bge-reranker/          # Reranker training results
│   └── 📁 checkpoints/           # Model checkpoints
│
└── 📁 logs/                      # Training logs
    ├── 📁 tensorboard/           # TensorBoard logs
    └── 📄 training.log           # Text logs
```

## 🎯 Directory Purposes

### 1. **`./cache/models/`** - Model Downloads Cache
- **Purpose**: HuggingFace/ModelScope model download cache
- **Contents**: Tokenizers, config files, temporary downloads
- **Usage**: Automatic caching of remote model downloads
- **Config**: `model_paths.cache_dir = "./cache/models"`

### 2. **`./cache/data/`** - Processed Dataset Cache
- **Purpose**: Preprocessed and cleaned dataset files
- **Contents**: Validated JSONL/JSON files
- **Usage**: Store cleaned and validated training data
- **Config**: `data.cache_dir = "./cache/data"`

### 3. **`./models/`** - Local Model Storage
- **Purpose**: Complete local model files (when available)
- **Contents**: Full model directories with all files
- **Usage**: Direct loading without downloads
- **Config**: `model_paths.bge_m3 = "./models/bge-m3"`

### 4. **`./data/`** - Raw Input Data
- **Purpose**: Original training/evaluation datasets
- **Contents**: `.jsonl` files in various formats
- **Usage**: Input to preprocessing and optimization
- **Config**: User-specified paths in training scripts

### 5. **`./output/`** - Training Results
- **Purpose**: Fine-tuned models and training artifacts
- **Contents**: Saved models, checkpoints, metrics
- **Usage**: Results of training scripts
- **Config**: `--output_dir` in training scripts

## ⚙️ Configuration Mapping

### Config File Settings
```toml
[model_paths]
cache_dir = "./cache/models"                # Model download cache
bge_m3 = "./models/bge-m3"                  # Local BGE-M3 path
bge_reranker = "./models/bge-reranker-base" # Local reranker path

[data]
cache_dir = "./cache/data"                  # Dataset cache
optimization_output_dir = "./cache/data"    # Optimization default
```

### Script Defaults
```bash
# Data optimization (processed datasets)
python -m data.optimization bge-m3 data/train.jsonl --output_dir ./cache/data

# Model training (output models)
python scripts/train_m3.py --output_dir ./output/bge-m3 --cache_dir ./cache/models

# Model evaluation (results)
python scripts/evaluate.py --output_dir ./evaluation_results
```

## 🔄 Data Flow

```
Raw Data   →  Preprocessing  →  Optimization  →  Training  →  Output
 data/                         cache/data/      models        output/
   ↓                               ↓               ↓             ↓
 Input JSONL               Pre-tokenized       Model        Fine-tuned
                            .pkl files         loading       models
```

## 🛠️ Usage Examples

### Model Cache (Download Cache)
```python
# Automatic model caching
tokenizer = AutoTokenizer.from_pretrained(
    "BAAI/bge-m3",
    cache_dir="./cache/models"  # Downloads cached here
)
```

### Data Cache (Optimization Cache)
```bash
# Create optimized datasets
python -m data.optimization bge-m3 data/train.jsonl \
    --output_dir ./cache/data  # .pkl files created here

# Use cached datasets in training
python scripts/train_m3.py \
    --train_data ./cache/data/train_m3_cached.pkl
```

### Local Models (No Downloads)
```python
# Use local model (no downloads)
model = BGEM3Model.from_pretrained("./models/bge-m3")
```

## 📋 Best Practices

1. **Separate Concerns**: Keep model cache, data cache, and outputs separate
2. **Clear Naming**: Use descriptive suffixes (`.pkl` for cache, `_cached` for optimized)
3. **Consistent Paths**: Use config.toml for centralized path management
4. **Cache Management**:
   - Model cache: Can be shared across projects
   - Data cache: Project-specific optimized datasets
   - Output: Training results and checkpoints

## 🧹 Cleanup Commands

```bash
# Clean data cache (re-optimization needed)
rm -rf ./cache/data/

# Clean model cache (re-download needed)
rm -rf ./cache/models/

# Clean training outputs
rm -rf ./output/

# Clean all caches
rm -rf ./cache/
```

This structure ensures clear separation of concerns and efficient caching for both model downloads and processed datasets.

docs/usage_guide.md (new file)
# BGE Fine-tuning Usage Guide

## Overview

This guide provides detailed instructions on fine-tuning BGE (BAAI General Embedding) models using our enhanced training framework. The framework supports both BGE-M3 (embedding) and BGE-reranker (cross-encoder) models with state-of-the-art training techniques.

## Table of Contents

1. [Installation](#installation)
2. [Quick Start](#quick-start)
3. [Data Preparation](#data-preparation)
4. [Training Scripts](#training-scripts)
5. [Advanced Usage](#advanced-usage)
6. [Troubleshooting](#troubleshooting)

## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/bge-finetune.git
cd bge-finetune

# Install dependencies
pip install -r requirements.txt

# Optional: Install Ascend NPU support
# pip install torch-npu  # If using Huawei Ascend NPUs
```

## Quick Start

### 1. Prepare Your Data

The framework supports multiple data formats:
- **JSONL** (recommended): One JSON object per line
- **CSV/TSV**: Tabular data with headers
- **JSON**: Single JSON file with an array of samples

Example JSONL format for the embedding model:
```json
{"query": "What is machine learning?", "pos": ["ML is a subset of AI..."], "neg": ["The weather today..."]}
{"query": "How to cook pasta?", "pos": ["Boil water and add pasta..."], "neg": ["Machine learning uses..."]}
```

### 2. Fine-tune BGE-M3 (Embedding Model)

```bash
python scripts/train_m3.py \
    --model_name_or_path BAAI/bge-m3 \
    --train_data data/train.jsonl \
    --output_dir ./output/bge-m3-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --learning_rate 1e-5
```

### 3. Fine-tune BGE-Reranker

```bash
python scripts/train_reranker.py \
    --model_name_or_path BAAI/bge-reranker-base \
    --train_data data/train_pairs.jsonl \
    --output_dir ./output/bge-reranker-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5
```

## Data Preparation

### Data Formats

#### 1. Embedding Model (Triplets Format)

```json
{
  "query": "query text",
  "pos": ["positive passage 1", "positive passage 2"],
  "neg": ["negative passage 1", "negative passage 2"],
  "pos_scores": [1.0, 0.9],  // Optional: relevance scores
  "neg_scores": [0.1, 0.2]   // Optional: relevance scores
}
```

#### 2. Reranker Model (Pairs Format)

```json
{
  "query": "query text",
  "passage": "passage text",
  "label": 1,     // 1 for relevant, 0 for not relevant
  "score": 0.95   // Optional: relevance score
}
```

### Performance Optimization

For small datasets (< 100k samples), enable in-memory caching:

```python
# In your custom script
dataset = BGEM3Dataset(
    data_path="train.jsonl",
    tokenizer=tokenizer,
    cache_in_memory=True,  # Pre-tokenize and cache in memory
    max_cache_size=100000  # Maximum samples to cache
)
```

## Training Scripts

### BGE-M3 Training Options

```bash
# Key options: --fp16 enables mixed-precision training, --train_group_size
# sets the number of passages per query, --use_hard_negatives enables hard
# negative mining, and --temperature sets the contrastive loss temperature.
python scripts/train_m3.py \
    --model_name_or_path BAAI/bge-m3 \
    --train_data data/train.jsonl \
    --eval_data data/eval.jsonl \
    --output_dir ./output/bge-m3-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 32 \
    --learning_rate 1e-5 \
    --warmup_ratio 0.1 \
    --gradient_accumulation_steps 2 \
    --fp16 \
    --train_group_size 8 \
    --use_hard_negatives \
    --temperature 0.02 \
    --save_steps 500 \
    --eval_steps 500 \
    --logging_steps 100
```

### BGE-Reranker Training Options

```bash
# Key options: --train_group_size sets the number of pairs per query;
# --gradient_checkpointing trades compute for memory.
python scripts/train_reranker.py \
    --model_name_or_path BAAI/bge-reranker-base \
    --train_data data/train_pairs.jsonl \
    --eval_data data/eval_pairs.jsonl \
    --output_dir ./output/bge-reranker-finetuned \
    --num_train_epochs 3 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5 \
    --max_length 512 \
    --train_group_size 16 \
    --gradient_checkpointing \
    --logging_steps 50
```

### Joint Training (RocketQAv2 Approach)

Train the retriever and reranker together:

```bash
python scripts/train_joint.py \
    --retriever_model BAAI/bge-m3 \
    --reranker_model BAAI/bge-reranker-base \
    --train_data data/train.jsonl \
    --output_dir ./output/joint-training \
    --num_train_epochs 3 \
    --alternate_steps 100  # Switch between models every N steps
```

## Advanced Usage

### Configuration Files

Use TOML configuration for complex setups:

```toml
# config.toml
[training]
default_batch_size = 16
default_num_epochs = 3
gradient_accumulation_steps = 2

[m3]
model_name_or_path = "BAAI/bge-m3"
train_group_size = 8
temperature = 0.02

[hardware]
device_type = "cuda"  # or "npu" for Ascend
```

Run with the config:
```bash
python scripts/train_m3.py --config_path config.toml --train_data data/train.jsonl
```

### Multi-GPU Training

```bash
# Single node, multiple GPUs
torchrun --nproc_per_node=4 scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir ./output/bge-m3-multigpu

# Multiple nodes
torchrun --nnodes=2 --nproc_per_node=4 \
    --rdzv_id=100 --rdzv_backend=c10d \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    scripts/train_m3.py --train_data data/train.jsonl
```

### Custom Data Preprocessing

```python
from data.preprocessing import DataPreprocessor

# Preprocess custom-format data
preprocessor = DataPreprocessor(
    tokenizer=tokenizer,
    max_query_length=64,
    max_passage_length=512
)

# Convert and clean data
output_path, val_path = preprocessor.preprocess_file(
    input_path="raw_data.csv",
    output_path="processed_data.jsonl",
    file_format="csv",
    validation_split=0.1
)
```

### Model Evaluation

```bash
python scripts/evaluate.py \
    --retriever_model ./output/bge-m3-finetuned \
    --reranker_model ./output/bge-reranker-finetuned \
    --eval_data data/test.jsonl \
    --output_dir ./evaluation_results \
    --metrics ndcg mrr recall \
    --top_k 10 20 50
```

## Troubleshooting

### Common Issues

1. **Out of Memory**
   - Reduce the batch size
   - Enable gradient checkpointing: `--gradient_checkpointing`
   - Use fp16 training: `--fp16`
   - Enable in-memory caching for small datasets

2. **Slow Training**
   - Increase the number of data loader workers: `--dataloader_num_workers 8`
   - Enable in-memory caching for datasets < 100k samples
   - Use an SSD for data storage

3. **Poor Performance**
   - Check data quality and format
   - Adjust the learning rate and warmup
   - Use hard negative mining: `--use_hard_negatives`
   - Increase training epochs

### Debugging Tips

```bash
# Enable debug logging
export LOG_LEVEL=DEBUG
python scripts/train_m3.py ...

# Profile training
python scripts/benchmark.py \
    --model_type retriever \
    --model_path ./output/bge-m3-finetuned \
    --data_path data/test.jsonl
```

## Best Practices

1. **Data Quality**
   - Ensure balanced positive/negative samples
   - Use relevance scores when available
   - Clean and normalize text data

2. **Training Strategy**
   - Start with a small learning rate (1e-5 to 2e-5)
   - Use warmup (10% of steps)
   - Monitor evaluation metrics

3. **Resource Optimization**
   - Use gradient accumulation for large batches
   - Enable mixed-precision training
   - Consider in-memory caching for small datasets

## Additional Resources

- [Data Format Examples](data_formats.md)
- [Model Architecture Details](../models/README.md)
- [Evaluation Metrics Guide](../evaluation/README.md)
- [API Reference](api_reference.md)

docs/validation_guide.md (new file)
# BGE Fine-tuning Validation Guide

This guide provides comprehensive instructions for validating your fine-tuned BGE models to ensure they actually perform better than the baseline models.

## 🎯 Overview

After fine-tuning BGE models, it's crucial to validate that your models have actually improved. This guide covers multiple validation approaches:

1. **Quick Validation** - Fast sanity checks
2. **Comprehensive Validation** - Detailed performance analysis
3. **Comparison Benchmarks** - Head-to-head baseline comparisons
4. **Pipeline Validation** - End-to-end retrieval + reranking tests
5. **Statistical Analysis** - Significance testing and detailed metrics

## 🚀 Quick Start

### Option 1: Run Full Validation Suite (Recommended)

The easiest way to validate your models is using the validation suite:

```bash
# Auto-discover trained models and run complete validation
python scripts/run_validation_suite.py --auto-discover

# Or specify models explicitly
python scripts/run_validation_suite.py \
    --retriever_model ./output/bge-m3-enhanced/final_model \
    --reranker_model ./output/bge-reranker/final_model
```

This runs all validation tests and generates a comprehensive report.

### Option 2: Quick Validation Only

For a fast check that your models are working correctly:

```bash
python scripts/quick_validation.py \
    --retriever_model ./output/bge-m3-enhanced/final_model \
    --reranker_model ./output/bge-reranker/final_model \
    --test_pipeline
```

## 📊 Validation Tools

### 1. Quick Validation (`quick_validation.py`)

**Purpose**: Fast sanity checks to verify models are working and show basic improvements.

**Features**:
- Tests model loading and basic functionality
- Uses predefined test queries in multiple languages
- Compares relevance scoring between baseline and fine-tuned models
- Tests pipeline performance (retrieval + reranking)

**Usage**:
```bash
# Test both models
python scripts/quick_validation.py \
    --retriever_model ./output/bge-m3-enhanced/final_model \
    --reranker_model ./output/bge-reranker/final_model \
    --test_pipeline

# Test only the retriever
python scripts/quick_validation.py \
    --retriever_model ./output/bge-m3-enhanced/final_model

# Save results to a file
python scripts/quick_validation.py \
    --retriever_model ./output/bge-m3-enhanced/final_model \
    --output_file ./results.json
```

**Output**: Console summary + optional JSON results file

### 2. Comprehensive Validation (`comprehensive_validation.py`)

**Purpose**: Detailed validation across multiple datasets with statistical analysis.

**Features**:
- Tests on all available datasets automatically
- Computes comprehensive metrics (Recall@K, Precision@K, MAP, MRR, NDCG)
- Statistical significance testing
- Performance degradation detection
- Generates HTML reports with visualizations
- Cross-dataset performance analysis

**Usage**:
```bash
# Comprehensive validation with auto-discovered datasets
python scripts/comprehensive_validation.py \
    --retriever_finetuned ./output/bge-m3-enhanced/final_model \
    --reranker_finetuned ./output/bge-reranker/final_model

# Use specific datasets
python scripts/comprehensive_validation.py \
    --retriever_finetuned ./output/bge-m3-enhanced/final_model \
    --test_datasets data/datasets/examples/embedding_data.jsonl \
        data/datasets/Reranker_AFQMC/dev.json \
    --output_dir ./validation_results

# Custom evaluation settings
python scripts/comprehensive_validation.py \
    --retriever_finetuned ./output/bge-m3-enhanced/final_model \
    --batch_size 16 \
    --k_values 1 3 5 10 20 \
    --retrieval_top_k 100 \
    --rerank_top_k 10
```

**Output**:
- JSON results file
- HTML report with detailed analysis
- Performance comparison tables
- Statistical significance analysis

### 3. Model Comparison Scripts

**Purpose**: Head-to-head performance comparisons with baseline models.

#### Retriever Comparison (`compare_retriever.py`)

```bash
python scripts/compare_retriever.py \
    --finetuned_model_path ./output/bge-m3-enhanced/final_model \
    --baseline_model_path BAAI/bge-m3 \
    --data_path data/datasets/examples/embedding_data.jsonl \
    --batch_size 16 \
    --max_samples 1000 \
    --output ./retriever_comparison.txt
```

#### Reranker Comparison (`compare_reranker.py`)

```bash
python scripts/compare_reranker.py \
    --finetuned_model_path ./output/bge-reranker/final_model \
    --baseline_model_path BAAI/bge-reranker-base \
    --data_path data/datasets/examples/reranker_data.jsonl \
    --batch_size 16 \
    --max_samples 1000 \
    --output ./reranker_comparison.txt
```

**Output**: Tab-separated comparison tables with metrics and performance deltas.

### 4. Validation Suite Runner (`run_validation_suite.py`)

**Purpose**: Orchestrates all validation tests in a systematic way.

**Features**:
- Auto-discovers trained models
- Runs multiple validation approaches
- Generates comprehensive reports
- Provides a final verdict on model quality

**Usage**:
```bash
# Full auto-discovery and validation
python scripts/run_validation_suite.py --auto-discover

# Specify models and validation types
python scripts/run_validation_suite.py \
    --retriever_model ./output/bge-m3-enhanced/final_model \
    --reranker_model ./output/bge-reranker/final_model \
    --validation_types quick comprehensive comparison

# Custom settings
python scripts/run_validation_suite.py \
    --retriever_model ./output/bge-m3-enhanced/final_model \
    --batch_size 32 \
    --max_samples 500 \
    --output_dir ./my_validation_results
```
|
||||
|
||||
**Output**:
|
||||
- Multiple result files from each validation type
|
||||
- HTML summary report
|
||||
- Final verdict and recommendations
|
||||
|
||||
### 5. Validation Utilities (`validation_utils.py`)
|
||||
|
||||
**Purpose**: Helper utilities for validation preparation and analysis.
|
||||
|
||||
**Features**:
|
||||
- Discover trained models in workspace
|
||||
- Analyze available test datasets
|
||||
- Create sample test data
|
||||
- Analyze validation results
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# Discover trained models
|
||||
python scripts/validation_utils.py --discover-models
|
||||
|
||||
# Analyze available test datasets
|
||||
python scripts/validation_utils.py --analyze-datasets
|
||||
|
||||
# Create sample test data
|
||||
python scripts/validation_utils.py --create-sample-data --output-dir ./test_data
|
||||
|
||||
# Analyze validation results
|
||||
python scripts/validation_utils.py --analyze-results ./validation_results
|
||||
```
|
||||
|
||||
## 📁 Understanding Results

### Result Files

After running validation, you'll get several result files:

1. **`validation_suite_report.html`** - Main HTML report with visual summary
2. **`validation_results.json`** - Complete detailed results in JSON format
3. **`quick_validation_results.json`** - Quick validation results
4. **`*_comparison_*.txt`** - Model comparison tables
5. **Comprehensive validation folder** with detailed metrics

### Interpreting Results

#### Overall Verdict

- **🌟 EXCELLENT** - Significant improvements across most metrics
- **✅ GOOD** - Clear improvements with minor degradations
- **👌 FAIR** - Mixed results, some improvements
- **⚠️ PARTIAL** - Some validation tests failed
- **❌ POOR** - More degradations than improvements

#### Key Metrics to Watch

**For Retrievers**:

- **Recall@K** - How many relevant documents are in the top-K results
- **MAP (Mean Average Precision)** - Overall ranking quality
- **MRR (Mean Reciprocal Rank)** - Position of the first relevant result
- **NDCG@K** - Normalized ranking quality with graded relevance

**For Rerankers**:

- **Accuracy** - Correct classification of relevant/irrelevant pairs
- **MRR@K** - Mean reciprocal rank in the top-K results
- **NDCG@K** - Ranking quality with position discounting

**For Pipeline**:

- **End-to-End Recall@K** - Overall system performance
- **Pipeline Accuracy** - Correct top results after retrieval + reranking

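For intuition, the two simplest of these metrics can be computed directly from a ranked result list. The sketch below is illustrative only; see `evaluation/metrics.py` for the implementations the scripts actually use:

```python
def recall_at_k(ranked_ids, relevant_ids, k):
    """Fraction of the relevant documents that appear in the top-k results."""
    if not relevant_ids:
        return 0.0
    hits = len(set(ranked_ids[:k]) & set(relevant_ids))
    return hits / len(relevant_ids)

def reciprocal_rank(ranked_ids, relevant_ids):
    """1/rank of the first relevant result, or 0.0 if none was retrieved."""
    for rank, doc_id in enumerate(ranked_ids, start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0

# MRR is simply the mean of reciprocal_rank over all evaluation queries.
```
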
#### Statistical Significance

The comprehensive validation includes statistical tests to determine if improvements are significant:

- **Improvement %** - Percentage improvement over baseline
- **Statistical Significance** - Whether improvements are statistically meaningful
- **Effect Size** - Magnitude of the improvement

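One common way to test whether a per-query metric improvement is meaningful is a paired bootstrap. The sketch below illustrates the general technique; it is not necessarily what the validation scripts implement:

```python
import random

def paired_bootstrap(baseline, finetuned, n_resamples=10_000, seed=0):
    """Fraction of bootstrap resamples where the fine-tuned model wins.

    Both inputs are per-query metric values (e.g. reciprocal ranks) in
    the same query order. Values near 1.0 suggest the improvement is
    robust rather than an artifact of a few lucky queries.
    """
    assert len(baseline) == len(finetuned)
    rng = random.Random(seed)
    n = len(baseline)
    wins = 0
    for _ in range(n_resamples):
        # Resample queries with replacement, keeping pairs aligned.
        idx = [rng.randrange(n) for _ in range(n)]
        if sum(finetuned[i] for i in idx) > sum(baseline[i] for i in idx):
            wins += 1
    return wins / n_resamples
```
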
## 🔧 Troubleshooting

### Common Issues

#### 1. Models Not Found

```
❌ Retriever model not found: ./output/bge-m3-enhanced/final_model
```

**Solution**:

- Check that training completed successfully
- Verify the model path exists
- Use `--discover-models` to find available models

#### 2. No Test Datasets Found

```
❌ No test datasets found!
```

**Solution**:

- Place test data in the `data/datasets/` directory
- Use `--create-sample-data` to generate sample datasets
- Specify datasets explicitly with `--test_datasets`

#### 3. Validation Failures

```
❌ Comprehensive validation failed
```

**Solution**:

- Check model compatibility (tokenizer issues)
- Reduce the batch size if running out of memory
- Verify the dataset format is correct
- Check the logs for detailed error messages

#### 4. Poor Performance Results

```
⚠️ More degradations than improvements
```

**Analysis Steps**:

1. Check training data quality and format
2. Verify hyperparameters (the learning rate might be too high)
3. Ensure sufficient training data and epochs
4. Check for data leakage or format mismatches
5. Consider different base models

### Validation Best Practices

#### 1. Before Running Validation

- ✅ Ensure training completed without errors
- ✅ Verify model files exist and are complete
- ✅ Check that test datasets are in the correct format
- ✅ Have baseline models available for comparison

#### 2. During Validation

- 📊 Start with quick validation for fast feedback
- 🔬 Run comprehensive validation on representative datasets
- 📈 Use multiple metrics to get a complete picture
- ⏱️ Be patient - comprehensive validation takes time

#### 3. Analyzing Results

- 🎯 Focus on the metrics most relevant to your use case
- 📊 Look at trends across multiple datasets
- ⚖️ Consider statistical significance, not just raw improvements
- 🔍 Investigate datasets where performance degraded

## 🎯 Validation Checklist

Use this checklist to ensure thorough validation:

### Pre-Validation Setup

- [ ] Models trained and saved successfully
- [ ] Test datasets available and formatted correctly
- [ ] Baseline models accessible (BAAI/bge-m3, BAAI/bge-reranker-base)
- [ ] Sufficient compute resources for validation

### Quick Validation (5-10 minutes)

- [ ] Run quick validation script
- [ ] Verify models load correctly
- [ ] Check basic performance improvements
- [ ] Test pipeline functionality (if both models available)

### Comprehensive Validation (30-60 minutes)

- [ ] Run comprehensive validation script
- [ ] Test on multiple datasets
- [ ] Analyze statistical significance
- [ ] Review HTML report for detailed insights

### Comparison Benchmarks (15-30 minutes)

- [ ] Run retriever comparison (if applicable)
- [ ] Run reranker comparison (if applicable)
- [ ] Analyze performance deltas
- [ ] Check throughput and memory usage

### Final Analysis

- [ ] Review overall verdict and recommendations
- [ ] Identify best and worst performing datasets
- [ ] Check for any critical issues or degradations
- [ ] Document findings and next steps

## 🚀 Advanced Usage

### Custom Test Datasets

Create your own test datasets for domain-specific validation. For retriever testing, use the embedding (triplets) format:

```jsonl
{"query": "Your query", "pos": ["relevant doc 1", "relevant doc 2"], "neg": ["irrelevant doc 1"]}
```

For reranker testing, use the pairs format:

```jsonl
{"query": "Your query", "passage": "Document text", "label": 1}
{"query": "Your query", "passage": "Irrelevant text", "label": 0}
```

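Before running validation on a custom dataset, it can help to sanity-check each line against these formats. A minimal sketch, assuming the field names shown above (`check_jsonl` is a hypothetical helper, not part of the project scripts):

```python
import json

def check_jsonl(path):
    """Count lines by detected format: triplets, pairs, or invalid."""
    counts = {"triplets": 0, "pairs": 0, "invalid": 0}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                counts["invalid"] += 1
                continue
            if "query" in row and "pos" in row:
                counts["triplets"] += 1
            elif "query" in row and "passage" in row and "label" in row:
                counts["pairs"] += 1
            else:
                counts["invalid"] += 1
    return counts
```

A dataset that mixes both formats, or reports any `invalid` lines, is worth fixing before validation.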
### Cross-Domain Validation

Test model generalization across different domains:

```bash
python scripts/comprehensive_validation.py \
    --retriever_finetuned ./output/bge-m3-enhanced/final_model \
    --test_datasets data/medical_qa.jsonl data/legal_docs.jsonl data/tech_support.jsonl
```

### Performance Profiling

Monitor resource usage during validation:

```bash
# Reduce --batch_size if you run into memory issues;
# limit --max_samples for a faster run.
python scripts/comprehensive_validation.py \
    --retriever_finetuned ./output/bge-m3-enhanced/final_model \
    --batch_size 8 \
    --max_samples 100
```

### A/B Testing Multiple Models

Compare different fine-tuning approaches:

```bash
# Test Model A
python scripts/comprehensive_validation.py \
    --retriever_finetuned ./output/model_A/final_model \
    --output_dir ./validation_A

# Test Model B
python scripts/comprehensive_validation.py \
    --retriever_finetuned ./output/model_B/final_model \
    --output_dir ./validation_B

# Compare results
python scripts/validation_utils.py --analyze-results ./validation_A
python scripts/validation_utils.py --analyze-results ./validation_B
```

## 📚 Additional Resources

- **BGE Model Documentation**: [BAAI/bge](https://github.com/FlagOpen/FlagEmbedding)
- **Evaluation Metrics Guide**: See `evaluation/metrics.py` for detailed metric implementations
- **Dataset Format Guide**: See `data/datasets/examples/format_usage_examples.py`
- **Training Guide**: See the main README for training instructions

## 💡 Tips for Better Validation

1. **Use Multiple Datasets**: Don't rely on a single test set
2. **Check Statistical Significance**: Small improvements might not be meaningful
3. **Monitor Resource Usage**: Ensure models are efficient enough for production
4. **Validate Domain Generalization**: Test on data similar to your production use case
5. **Document Results**: Keep validation reports for future reference
6. **Iterate Based on Results**: Use validation insights to improve training

## 🤝 Getting Help

If you encounter issues with validation:

1. Check the troubleshooting section above
2. Review the logs for detailed error messages
3. Use `--discover-models` and `--analyze-datasets` to debug setup issues
4. Start with quick validation to isolate problems
5. Check that model and dataset formats are correct

Remember: good validation is key to successful fine-tuning. Take time to thoroughly test your models before deploying them!