# BGE Fine-tuning for Chinese Book Retrieval

Robust, production-ready tooling to fine-tune the BAAI **BGE-M3** embedding model and **BGE-Reranker** cross-encoder for high-performance retrieval of Chinese book content (with English query support).

---

## Table of Contents
1. Overview
2. Requirements & Compatibility
3. Core Features
4. Installation
5. Quick Start
6. Configuration System
7. Training & Evaluation
8. Distributed & Ascend NPU Support
9. Troubleshooting & FAQ
10. References

---

## 1. Overview
* Unified repository for dense, sparse (SPLADE-style), and multi-vector (ColBERT-style) retrieval.
* Joint retriever-reranker training (RocketQAv2-style) and ANCE-style hard-negative mining included.
* Built-in evaluation, benchmarking, and statistical significance testing.

---

## 2. Requirements & Compatibility
* **Python ≥ 3.10**
* **PyTorch ≥ 2.2.0** (nightly recommended for CUDA 12.8+)
* **CUDA ≥ 12.8** *(RTX 5090 / A100 / H100)* or **Torch-NPU** *(Ascend 910 / 910B)*

See `requirements.txt` for the full dependency list.

---
## 3. Core Features

* **InfoNCE (τ = 0.02) & Listwise Cross-Entropy** losses with self-distillation (see the sketch after this list).
* **ANCE-style Hard-Negative Mining** with FAISS index refresh (also sketched below).
* **Atomic Checkpointing** & robust failure recovery.
* **Automatic Mixed Precision** (fp16 / bf16) & gradient checkpointing.
* **Comprehensive Metrics**: Recall@k, MRR@k, NDCG@k, MAP, Success-Rate.
* **Cross-hardware Support**: CUDA (NCCL), Ascend NPU (HCCL), CPU (Gloo).

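For reference, the loss and mining features reduce to a few lines each. First, a minimal sketch of temperature-scaled InfoNCE with in-batch negatives (τ = 0.02 as above; tensor names are illustrative, not this repository's API):

```python
import torch
import torch.nn.functional as F

def info_nce_loss(q: torch.Tensor, p: torch.Tensor, tau: float = 0.02) -> torch.Tensor:
    """InfoNCE with in-batch negatives: row i of `p` is the positive for
    query i; every other row in the batch acts as a negative."""
    q = F.normalize(q, dim=-1)                 # (B, D) query embeddings
    p = F.normalize(p, dim=-1)                 # (B, D) positive passage embeddings
    logits = q @ p.T / tau                     # (B, B) cosine similarities, sharpened by tau
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, labels)
```

Second, a simplified view of ANCE-style mining: periodically re-encode the corpus, rebuild a FAISS index, and keep top-ranked non-positives as hard negatives (the helper below is hypothetical, not the repository's implementation):

```python
import faiss
import numpy as np

def mine_hard_negatives(q_emb: np.ndarray, c_emb: np.ndarray,
                        positives: list[set[int]], depth: int = 30) -> list[list[int]]:
    """Search a freshly built index and keep the top-ranked passages
    that are not labelled positive for each query."""
    index = faiss.IndexFlatIP(c_emb.shape[1])  # inner product == cosine on normalized vectors
    index.add(c_emb.astype(np.float32))
    _, ranked = index.search(q_emb.astype(np.float32), depth)
    return [[int(d) for d in row if d not in positives[i]]
            for i, row in enumerate(ranked)]
```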
---
## 4. Installation

```bash
# 1. Install PyTorch nightly (CUDA 12.8)
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128

# 2. Install project dependencies
pip install -r requirements.txt

# 3. (Optional) Ascend NPU backend
pip install torch_npu  # then edit config.toml [ascend]
```
Models auto-download on first use, or fetch manually:

```bash
huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3
huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base
```
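To sanity-check the downloaded weights, a quick smoke test via the upstream `FlagEmbedding` package works (an assumption: it is not listed in `requirements.txt` above, and this snippet is independent of the training code here):

```python
from FlagEmbedding import BGEM3FlagModel, FlagReranker

model = BGEM3FlagModel("./models/bge-m3", use_fp16=True)
reranker = FlagReranker("./models/bge-reranker-base", use_fp16=True)

# Dense-embed a Chinese query and score a query-passage pair.
dense = model.encode(["红楼梦的作者是谁？"])["dense_vecs"]  # "Who wrote Dream of the Red Chamber?"
score = reranker.compute_score(["红楼梦的作者是谁？", "《红楼梦》是曹雪芹创作的长篇小说。"])
print(dense.shape, score)
```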
---
## 5. Quick Start
### Fine-tune BGE-M3 (Embedding)

```bash
python scripts/train_m3.py \
    --train_data data/train.jsonl \
    --output_dir output/bge-m3 \
    --per_device_train_batch_size 8 \
    --learning_rate 1e-5
```
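`--train_data` expects JSONL with one example per line. The field layout below follows the common BGE/FlagEmbedding convention (`query` / `pos` / `neg`) and is an assumption about this repo's loader, written as a runnable snippet:

```python
import json

# One hypothetical training example (assumed schema: query, positive passages, hard negatives).
record = {
    "query": "贾宝玉的性格特点",                 # query: Jia Baoyu's character traits
    "pos": ["宝玉生性多情，蔑视功名利禄……"],      # passages that should rank high
    "neg": ["《三国演义》以东汉末年为背景……"],    # hard negatives that should rank low
}
with open("data/train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```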
### Fine-tune BGE-Reranker (Cross-encoder)

```bash
python scripts/train_reranker.py \
    --train_data data/train.jsonl \
    --output_dir output/bge-reranker \
    --learning_rate 6e-5 \
    --train_group_size 16
```
### Joint Training (RocketQAv2-style)

```bash
python scripts/train_joint.py \
    --train_data data/train.jsonl \
    --output_dir output/joint-training \
    --bf16 --gradient_checkpointing
```
---
## 6. Configuration System
Configuration is **TOML-driven** with clear precedence (earlier sources override later ones):

1. Command-line arguments
2. Model sections (`[m3]` / `[reranker]`)
3. Hardware sections (`[hardware]` / `[ascend]`)
4. Training defaults (`[training]`)
5. Dataclass field defaults

Load & validate:

```python
from config.base_config import BaseConfig

cfg = BaseConfig.from_toml("config.toml")
```

Key sections: `[training]`, `[hardware]`, `[distributed]`, `[m3]`, `[reranker]`, `[augmentation]`, `[checkpoint]`, `[logging]`.

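To make the override order concrete, here is a toy illustration of layered merging (purely illustrative; the real logic lives inside `BaseConfig.from_toml`):

```python
# Apply layers from lowest to highest precedence; later updates win.
layers = [
    {"learning_rate": 2e-5, "batch_size": 4},  # dataclass field defaults
    {"learning_rate": 1e-5},                   # [training]
    {"batch_size": 8},                         # [hardware]
    {},                                        # [m3] (nothing overridden here)
    {"learning_rate": 5e-6},                   # command-line flags
]
resolved: dict = {}
for layer in layers:
    resolved.update(layer)
print(resolved)  # {'learning_rate': 5e-06, 'batch_size': 8}
```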
---
## 7. Training & Evaluation
* **Scripts**: `train_m3.py`, `train_reranker.py`, `train_joint.py`
* **Evaluation**: `scripts/evaluate.py` evaluates the retriever, the reranker, or the full pipeline.
* **Benchmarking**: `scripts/benchmark.py` for latency / throughput; `compare_*` scripts for baseline comparisons.

Example evaluation:

```bash
python scripts/evaluate.py \
    --retriever_model output/bge-m3 \
    --eval_data data/test.jsonl \
    --k_values 1 5 10
```
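The headline metrics reduce to a few lines each; a self-contained sketch of Recall@k and MRR@k (for orientation only, not the code in `evaluate.py`):

```python
def recall_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
    """Fraction of the relevant documents that appear in the top-k results."""
    return len(set(ranked[:k]) & relevant) / len(relevant)

def mrr_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
    """Reciprocal rank of the first relevant hit within the top k, else 0."""
    for rank, doc_id in enumerate(ranked[:k], start=1):
        if doc_id in relevant:
            return 1.0 / rank
    return 0.0

# A query whose only relevant doc is ranked third scores MRR@5 = 1/3:
assert mrr_at_k(["d9", "d4", "d7"], {"d7"}, k=5) == 1 / 3
```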
---
## 8. Distributed & Ascend NPU Support
### Multi-GPU (NVIDIA)

```bash
torchrun --nproc_per_node=4 scripts/train_m3.py --train_data data/train.jsonl --output_dir output/distributed
```

Backend auto-selects **NCCL**. Use the `[distributed]` section in `config.toml` for fine-grained control.
### Huawei Ascend NPU

```bash
pip install torch_npu
python scripts/train_m3.py --device npu --train_data data/train.jsonl --output_dir output/ascend
```

Backend auto-selects **HCCL**.

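The choice reduces to probing the accelerator stack; a simplified sketch of such dispatch logic (illustrative, not necessarily this repo's exact code):

```python
import torch

def pick_backend() -> str:
    """NCCL on NVIDIA GPUs, HCCL on Ascend NPUs, Gloo as the CPU fallback."""
    if torch.cuda.is_available():
        return "nccl"
    try:
        import torch_npu  # noqa: F401 -- registers the 'npu' device with torch
        if torch.npu.is_available():
            return "hccl"
    except ImportError:
        pass
    return "gloo"
```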
---
## 9. Troubleshooting & FAQ
* **CUDA mismatch**: install a PyTorch build that matches your driver's CUDA version (verify with `nvidia-smi`).
* **Ascend setup**: ensure `torch_npu` is installed and the `[ascend]` section is configured.
* **NCCL errors**: set `NCCL_DEBUG=INFO`; disable InfiniBand if needed (`NCCL_IB_DISABLE=1`).
* **Memory issues**: enable `--gradient_checkpointing` and mixed precision.
---
## 10. References

1. BGE-M3 Embeddings – [arXiv:2402.03216](https://arxiv.org/abs/2402.03216)
2. RocketQAv2 – EMNLP 2021, [arXiv:2110.07367](https://arxiv.org/abs/2110.07367)
3. ANCE – [arXiv:2007.00808](https://arxiv.org/abs/2007.00808)
4. ColBERT – SIGIR 2020, [arXiv:2004.12832](https://arxiv.org/abs/2004.12832)
5. SPLADE – [arXiv:2109.10086](https://arxiv.org/abs/2109.10086)