BGE Fine-tuning for Chinese Book Retrieval
Robust, production-ready tooling to fine-tune the BAAI BGE-M3 embedding model and BGE-Reranker cross-encoder for high-performance retrieval of Chinese book content (with English query support).
Table of Contents
- Overview
- Requirements & Compatibility
- Core Features
- Installation
- Quick Start
- Configuration System
- Training & Evaluation
- Distributed & Ascend NPU Support
- Troubleshooting & FAQ
- References
- Roadmap
1. Overview
- Unified repository for dense, sparse (SPLADE-style) and multi-vector (ColBERT-style) retrieval.
- Joint training (RocketQAv2) and ANCE-style hard-negative mining included.
- Built-in evaluation, benchmarking and statistical significance testing.
2. Requirements & Compatibility
- Python ≥ 3.10
- PyTorch ≥ 2.2.0 (nightly recommended for CUDA 12.8+)
- CUDA ≥ 12.8 (RTX 5090 / A100 / H100) or Torch-NPU (Ascend 910 / 910B)
See requirements.txt for the full dependency list.
3. Core Features
- InfoNCE (τ = 0.02) & Listwise CE losses with self-distillation (a loss sketch follows this list).
- ANCE-style Hard-Negative Mining with FAISS index refresh.
- Atomic Checkpointing & robust failure recovery.
- Automatic Mixed Precision (fp16 / bf16) & gradient checkpointing.
- Comprehensive Metrics: Recall@k, MRR@k, NDCG@k, MAP, Success-Rate.
- Cross-hardware Support: CUDA (NCCL), Ascend NPU (HCCL), CPU (Gloo).
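For reference, here is a minimal PyTorch sketch of the InfoNCE objective at τ = 0.02 with in-batch negatives; it illustrates the loss family only, not the repository's exact implementation (self-distillation and the listwise CE term are omitted):
import torch
import torch.nn.functional as F

def info_nce(q, p, temperature=0.02):
    # q: (B, d) query embeddings, p: (B, d) positive passage embeddings.
    # Every other positive in the batch serves as an in-batch negative.
    q = F.normalize(q, dim=-1)
    p = F.normalize(p, dim=-1)
    logits = q @ p.T / temperature                 # (B, B) similarity matrix
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, labels)         # diagonal entries are the positives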
4. Installation
# 1. Install PyTorch nightly (CUDA 12.8)
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
# 2. Install project dependencies
pip install -r requirements.txt
# 3. (Optional) Ascend NPU backend
pip install torch_npu # then edit the [ascend] section in config.toml
Models download automatically on first use, or fetch them manually:
huggingface-cli download BAAI/bge-m3 --local-dir ./models/bge-m3
huggingface-cli download BAAI/bge-reranker-base --local-dir ./models/bge-reranker-base
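To sanity-check the downloaded embedding model, here is a minimal sketch using the FlagEmbedding package (an assumption; it is not listed elsewhere in this README and may need a separate pip install FlagEmbedding):
from FlagEmbedding import BGEM3FlagModel

# Load the locally downloaded BGE-M3 checkpoint and encode a Chinese/English pair.
model = BGEM3FlagModel("./models/bge-m3", use_fp16=True)
vecs = model.encode(["红楼梦的作者是谁?", "Who wrote Dream of the Red Chamber?"])["dense_vecs"]
print(vecs.shape)  # (2, 1024) dense vectors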
5. Quick Start
Fine-tune BGE-M3 (Embedding)
python scripts/train_m3.py \
--train_data data/train.jsonl \
--output_dir output/bge-m3 \
--per_device_train_batch_size 8 \
--learning_rate 1e-5
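The schema of data/train.jsonl is not spelled out here; a common convention for BGE-style fine-tuning is one JSON object per line with a query, positive passages, and hard negatives. The field names below are an illustrative assumption, not the repository's documented format:
import json

# Hypothetical record layout: one query plus positive and hard-negative passages.
record = {
    "query": "《西游记》的作者是谁?",
    "pos": ["《西游记》是明代吴承恩所著的长篇神魔小说……"],
    "neg": ["《水浒传》描写了北宋末年宋江等人起义的故事……"],
}
with open("data/train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")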
Fine-tune BGE-Reranker (Cross-encoder)
python scripts/train_reranker.py \
--train_data data/train.jsonl \
--output_dir output/bge-reranker \
--learning_rate 6e-5 \
--train_group_size 16
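After training, the cross-encoder scores query-passage pairs directly. A minimal usage sketch with FlagEmbedding's FlagReranker (an assumption; the repository may ship its own inference wrapper):
from FlagEmbedding import FlagReranker

# Load the fine-tuned cross-encoder and score query-passage pairs.
reranker = FlagReranker("output/bge-reranker", use_fp16=True)
scores = reranker.compute_score([
    ["三国演义的主要人物有哪些?", "《三国演义》以曹操、刘备、孙权等人物为中心……"],
    ["三国演义的主要人物有哪些?", "《红楼梦》讲述贾宝玉与林黛玉的爱情故事。"],
])
print(scores)  # higher score = more relevant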
Joint Training (RocketQAv2-style)
python scripts/train_joint.py \
--train_data data/train.jsonl \
--output_dir output/joint-training \
--bf16 --gradient_checkpointing
6. Configuration System
Configuration is TOML-driven with clear precedence:
- 1. Command-line → 2. Model ([m3]/[reranker]) → 3. Hardware ([hardware]/[ascend]) → 4. Training defaults ([training]) → 5. Dataclass field defaults.
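A hedged illustration of that precedence as a plain dictionary merge (the layout is hypothetical; the actual resolution lives inside BaseConfig):
# Later sources override earlier ones, so command-line arguments always win.
dataclass_defaults = {"learning_rate": 2e-5, "bf16": False}  # 5. dataclass field defaults
training_section = {"learning_rate": 1e-5}                   # 4. [training]
hardware_section = {"bf16": True}                            # 3. [hardware] / [ascend]
model_section = {}                                           # 2. [m3] / [reranker]
cli_args = {"learning_rate": 5e-6}                           # 1. command line

effective = {**dataclass_defaults, **training_section, **hardware_section, **model_section, **cli_args}
print(effective)  # {'learning_rate': 5e-06, 'bf16': True}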
Load & validate:
from config.base_config import BaseConfig
cfg = BaseConfig.from_toml("config.toml")
Key sections: [training], [hardware], [distributed], [m3], [reranker], [augmentation], [checkpoint], [logging].
7. Training & Evaluation
- Scripts: train_m3.py, train_reranker.py, train_joint.py
- Evaluation: scripts/evaluate.py supports retriever, reranker, or full-pipeline evaluation.
- Benchmarking: scripts/benchmark.py for latency / throughput; compare_* scripts for baseline comparisons.
Example evaluation:
python scripts/evaluate.py \
--retriever_model output/bge-m3 \
--eval_data data/test.jsonl \
--k_values 1 5 10
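For reference, a self-contained sketch of two of the reported metrics, Recall@k and MRR@k; this is an illustration, not the code in scripts/evaluate.py:
def recall_at_k(ranked_ids, relevant_ids, k):
    # Fraction of the relevant documents that appear in the top-k ranking.
    hits = len(set(ranked_ids[:k]) & set(relevant_ids))
    return hits / max(len(relevant_ids), 1)

def mrr_at_k(ranked_ids, relevant_ids, k):
    # Reciprocal rank of the first relevant document within the top k, else 0.
    for rank, doc_id in enumerate(ranked_ids[:k], start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0

ranked = ["d3", "d7", "d1"]
print(recall_at_k(ranked, {"d1"}, 10), mrr_at_k(ranked, {"d1"}, 10))  # 1.0 0.333...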
8. Distributed & Ascend NPU Support
Multi-GPU (NVIDIA)
torchrun --nproc_per_node=4 scripts/train_m3.py --train_data data/train.jsonl --output_dir output/distributed
Backend auto-selects NCCL. Use the [distributed] section in config.toml for fine-grained control.
Huawei Ascend NPU
pip install torch_npu
python scripts/train_m3.py --device npu --train_data data/train.jsonl --output_dir output/ascend
Backend auto-selects HCCL.
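A hedged sketch of how this kind of backend auto-selection can be implemented (illustrative only; the repository's device handling may differ):
import torch
import torch.distributed as dist

def pick_backend():
    # NCCL for NVIDIA GPUs, HCCL for Ascend NPUs (via torch_npu), Gloo as the CPU fallback.
    if torch.cuda.is_available():
        return "nccl"
    try:
        import torch_npu  # noqa: F401  (registers the "npu" device with PyTorch)
        if hasattr(torch, "npu") and torch.npu.is_available():
            return "hccl"
    except ImportError:
        pass
    return "gloo"

# Assumes the process was launched with torchrun, which sets the rendezvous env vars.
dist.init_process_group(backend=pick_backend())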
9. Troubleshooting & FAQ
- CUDA mismatch: install the matching PyTorch + CUDA build (verify with nvidia-smi).
- Ascend setup: ensure torch_npu is installed and the [ascend] config section is set.
- NCCL errors: set NCCL_DEBUG=INFO and disable InfiniBand if needed (NCCL_IB_DISABLE=1).
- Memory issues: enable --gradient_checkpointing and mixed precision.
10. References
- BGE-M3 Embeddings – arXiv:2402.03216
- RocketQAv2 – EMNLP 2021
- ANCE – arXiv:2007.00808
- ColBERT – SIGIR 2020
- SPLADE – arXiv:2109.10086