# bge_finetune/scripts/train_joint.py
# Last modified: 2025-07-22 16:55:25 +08:00
# 498 lines, 20 KiB, Python
"""
Joint training of BGE-M3 and BGE-Reranker using RocketQAv2 approach
"""
import os
import sys
import argparse
# Set tokenizer parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import logging
import json
from datetime import datetime
import torch
from transformers import AutoTokenizer, set_seed
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.m3_config import BGEM3Config
from config.reranker_config import BGERerankerConfig
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset, JointDataset
from data.preprocessing import DataPreprocessor
from training.joint_trainer import JointTrainer
from utils.logging import setup_logging, get_logger
from utils.distributed import setup_distributed, is_main_process, set_random_seed
logger = logging.getLogger(__name__)
def parse_args():
    """Parse and validate command-line arguments for joint training.

    Returns:
        argparse.Namespace: parsed arguments.

    Raises:
        ValueError: if ``--output_dir`` exists, is non-empty, and
            ``--overwrite_output_dir`` was not passed.
    """
    parser = argparse.ArgumentParser(description="Joint training of BGE-M3 and BGE-Reranker")
    # Model arguments
    parser.add_argument("--retriever_model", type=str, default="BAAI/bge-m3",
                        help="Path to retriever model")
    parser.add_argument("--reranker_model", type=str, default="BAAI/bge-reranker-base",
                        help="Path to reranker model")
    parser.add_argument("--cache_dir", type=str, default="./cache/data",
                        help="Directory to cache downloaded models")
    # Data arguments
    parser.add_argument("--train_data", type=str, required=True,
                        help="Path to training data file")
    parser.add_argument("--eval_data", type=str, default=None,
                        help="Path to evaluation data file")
    parser.add_argument("--data_format", type=str, default="auto",
                        choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
                        help="Format of input data")
    # Training arguments
    parser.add_argument("--output_dir", type=str, default="./output/joint-training",
                        help="Directory to save the models")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=4,
                        help="Batch size per GPU/CPU for training")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
                        help="Batch size per GPU/CPU for evaluation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
                        help="Number of updates steps to accumulate")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Initial learning rate for retriever")
    parser.add_argument("--reranker_learning_rate", type=float, default=6e-5,
                        help="Initial learning rate for reranker")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Ratio of warmup steps")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")
    # Joint training specific
    parser.add_argument("--joint_loss_weight", type=float, default=0.3,
                        help="Weight for joint distillation loss")
    parser.add_argument("--update_retriever_steps", type=int, default=100,
                        help="Update retriever every N steps")
    parser.add_argument("--use_dynamic_distillation", action="store_true",
                        help="Use dynamic listwise distillation")
    # Model configuration
    parser.add_argument("--retriever_train_group_size", type=int, default=8,
                        help="Number of passages per query for retriever")
    parser.add_argument("--reranker_train_group_size", type=int, default=16,
                        help="Number of passages per query for reranker")
    parser.add_argument("--temperature", type=float, default=0.02,
                        help="Temperature for InfoNCE loss")
    # BGE-M3 specific
    # NOTE(review): store_true with default=True means this flag is always
    # True and cannot be disabled from the CLI; kept as-is for interface
    # compatibility — a --no_dense flag would be needed to make it togglable.
    parser.add_argument("--use_dense", action="store_true", default=True,
                        help="Use dense embeddings")
    parser.add_argument("--use_sparse", action="store_true",
                        help="Use sparse embeddings")
    parser.add_argument("--use_colbert", action="store_true",
                        help="Use ColBERT embeddings")
    parser.add_argument("--query_max_length", type=int, default=512,
                        help="Maximum query length")
    parser.add_argument("--passage_max_length", type=int, default=8192,
                        help="Maximum passage length")
    # Reranker specific
    parser.add_argument("--reranker_max_length", type=int, default=512,
                        help="Maximum sequence length for reranker")
    # Performance
    parser.add_argument("--fp16", action="store_true",
                        help="Use mixed precision training")
    parser.add_argument("--gradient_checkpointing", action="store_true",
                        help="Use gradient checkpointing")
    parser.add_argument("--dataloader_num_workers", type=int, default=4,
                        help="Number of workers for data loading")
    # Checkpointing
    parser.add_argument("--save_steps", type=int, default=1000,
                        help="Save checkpoint every N steps")
    parser.add_argument("--eval_steps", type=int, default=1000,
                        help="Evaluate every N steps")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Log every N steps")
    parser.add_argument("--save_total_limit", type=int, default=3,
                        help="Maximum number of checkpoints to keep")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Path to checkpoint to resume from")
    # Distributed training
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training")
    # Other
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--overwrite_output_dir", action="store_true",
                        help="Overwrite the output directory")
    args = parser.parse_args()
    # Validate arguments.
    # Only reject a *non-empty* existing output dir: an empty directory may
    # have been pre-created by a job launcher (or by another distributed
    # rank) and is harmless to reuse. The original check rejected any
    # existing directory, which broke those cases.
    if (os.path.isdir(args.output_dir)
            and os.listdir(args.output_dir)
            and not args.overwrite_output_dir):
        raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.")
    return args
def main():
    """Entry point: configure and run joint retriever/reranker training.

    Pipeline: parse args -> set up logging and seeding -> select device
    (Ascend NPU / CUDA / CPU, optionally distributed) -> build retriever,
    reranker, and joint configs -> preprocess data -> build datasets and
    dataloaders -> create models, optimizers, and JointTrainer -> train ->
    save final models and a JSON training summary (main process only).
    """
    # Parse arguments
    args = parse_args()
    # Setup logging (console + file under output_dir)
    setup_logging(
        output_dir=args.output_dir,
        log_level=logging.INFO,
        log_to_console=True,
        log_to_file=True
    )
    logger = get_logger(__name__)
    logger.info("Starting joint training of BGE-M3 and BGE-Reranker")
    logger.info(f"Arguments: {args}")
    # Set random seed; deterministic=False leaves faster non-deterministic
    # kernels enabled.
    set_random_seed(args.seed, deterministic=False)
    # Setup distributed training if needed, then pick the compute device.
    # Priority in both branches: Ascend NPU (if configured) -> CUDA -> CPU.
    device = None
    if args.local_rank != -1:
        setup_distributed()
        from utils.config_loader import get_config
        config = get_config()
        device_type = config.get('ascend.device_type', None)  # type: ignore[attr-defined]
        if device_type == 'npu':
            try:
                import torch_npu
                torch_npu.npu.set_device(args.local_rank)
                device = torch.device("npu", args.local_rank)
                logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
            except ImportError:
                # NPU configured but torch_npu missing: degrade to CPU
                # rather than abort.
                logger.warning("torch_npu not installed, falling back to CPU")
                device = torch.device("cpu")
        elif torch.cuda.is_available():
            torch.cuda.set_device(args.local_rank)
            device = torch.device("cuda", args.local_rank)
            logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}")
        else:
            device = torch.device("cpu")
            logger.info("Distributed training enabled: using CPU")
    else:
        # Single-process path: same device-selection order, without ranks.
        from utils.config_loader import get_config
        config = get_config()
        device_type = config.get('ascend.device_type', None)  # type: ignore[attr-defined]
        if device_type == 'npu':
            try:
                import torch_npu
                device = torch.device("npu")
                logger.info("Using Ascend NPU")
            except ImportError:
                logger.warning("torch_npu not installed, falling back to CPU")
                device = torch.device("cpu")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
            logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})")
        else:
            device = torch.device("cpu")
            logger.info("Using CPU")
    # Create configurations.
    # Retriever and reranker intentionally share train data, schedule, and
    # checkpointing settings but differ in learning rate, group size, and
    # sequence-length limits.
    retriever_config = BGEM3Config(
        model_name_or_path=args.retriever_model,
        cache_dir=args.cache_dir,
        train_data_path=args.train_data,
        eval_data_path=args.eval_data,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
        output_dir=os.path.join(args.output_dir, "retriever"),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        train_group_size=args.retriever_train_group_size,
        temperature=args.temperature,
        use_dense=args.use_dense,
        use_sparse=args.use_sparse,
        use_colbert=args.use_colbert,
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        dataloader_num_workers=args.dataloader_num_workers,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        save_total_limit=args.save_total_limit,
        seed=args.seed,
        local_rank=args.local_rank,
        device=str(device)
    )
    reranker_config = BGERerankerConfig(
        model_name_or_path=args.reranker_model,
        cache_dir=args.cache_dir,
        train_data_path=args.train_data,
        eval_data_path=args.eval_data,
        max_seq_length=args.reranker_max_length,
        output_dir=os.path.join(args.output_dir, "reranker"),
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.reranker_learning_rate,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        train_group_size=args.reranker_train_group_size,
        joint_training=True,  # type: ignore[call-arg]
        joint_loss_weight=args.joint_loss_weight,  # type: ignore[call-arg]
        use_dynamic_listwise_distillation=args.use_dynamic_distillation,  # type: ignore[call-arg]
        update_retriever_steps=args.update_retriever_steps,  # type: ignore[call-arg]
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        dataloader_num_workers=args.dataloader_num_workers,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        save_total_limit=args.save_total_limit,
        seed=args.seed,
        local_rank=args.local_rank,
        device=str(device)
    )
    # Create joint config (use retriever config as base).
    # NOTE(review): this is an ALIAS, not a copy — the attribute writes below
    # also mutate retriever_config (e.g. its output_dir), which is then saved
    # as retriever_config.json further down. Confirm that is intended.
    joint_config = retriever_config
    joint_config.output_dir = args.output_dir
    joint_config.joint_loss_weight = args.joint_loss_weight  # type: ignore[attr-defined]
    joint_config.use_dynamic_listwise_distillation = args.use_dynamic_distillation  # type: ignore[attr-defined]
    joint_config.update_retriever_steps = args.update_retriever_steps  # type: ignore[attr-defined]
    # Save configurations (rank 0 only)
    if is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)
        joint_config.save(os.path.join(args.output_dir, "joint_config.json"))
        retriever_config.save(os.path.join(args.output_dir, "retriever_config.json"))
        reranker_config.save(os.path.join(args.output_dir, "reranker_config.json"))
    # Load tokenizers (retriever and reranker may use different vocabularies)
    logger.info("Loading tokenizers...")
    retriever_tokenizer = AutoTokenizer.from_pretrained(
        args.retriever_model,
        cache_dir=args.cache_dir
    )
    reranker_tokenizer = AutoTokenizer.from_pretrained(
        args.reranker_model,
        cache_dir=args.cache_dir
    )
    # Prepare data
    logger.info("Preparing data...")
    preprocessor = DataPreprocessor(
        tokenizer=retriever_tokenizer,  # Use retriever tokenizer for preprocessing
        max_query_length=args.query_max_length,
        max_passage_length=args.passage_max_length
    )
    # Preprocess training data; hold out 10% for validation only when no
    # explicit eval file was given.
    # NOTE(review): in distributed mode every rank runs this preprocessing
    # and writes the same output file — confirm preprocess_file is safe for
    # concurrent writers.
    train_path, val_path = preprocessor.preprocess_file(
        input_path=args.train_data,
        output_path=os.path.join(args.output_dir, "processed_train.jsonl"),
        file_format=args.data_format,
        validation_split=0.1 if args.eval_data is None else 0.0
    )
    # Use provided eval data or split validation
    eval_path = args.eval_data if args.eval_data else val_path
    # Create datasets
    logger.info("Creating datasets...")
    # Retriever dataset (query/passage pairs, grouped per query)
    retriever_train_dataset = BGEM3Dataset(
        data_path=train_path,
        tokenizer=retriever_tokenizer,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
        train_group_size=args.retriever_train_group_size,
        is_train=True
    )
    # Reranker dataset (concatenated query+passage sequences)
    reranker_train_dataset = BGERerankerDataset(
        data_path=train_path,
        tokenizer=reranker_tokenizer,
        max_seq_length=args.reranker_max_length,  # type: ignore[call-arg]
        train_group_size=args.reranker_train_group_size,
        is_train=True
    )
    # Joint dataset pairs the two views of the same training file
    joint_train_dataset = JointDataset(  # type: ignore[call-arg]
        retriever_dataset=retriever_train_dataset,  # type: ignore[call-arg]
        reranker_dataset=reranker_train_dataset  # type: ignore[call-arg]
    )
    # Evaluation datasets (only when an eval path exists)
    joint_eval_dataset = None
    if eval_path:
        retriever_eval_dataset = BGEM3Dataset(
            data_path=eval_path,
            tokenizer=retriever_tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            train_group_size=args.retriever_train_group_size,
            is_train=False
        )
        reranker_eval_dataset = BGERerankerDataset(
            data_path=eval_path,
            tokenizer=reranker_tokenizer,
            max_seq_length=args.reranker_max_length,  # type: ignore[call-arg]
            train_group_size=args.reranker_train_group_size,
            is_train=False
        )
        joint_eval_dataset = JointDataset(  # type: ignore[call-arg]
            retriever_dataset=retriever_eval_dataset,  # type: ignore[call-arg]
            reranker_dataset=reranker_eval_dataset  # type: ignore[call-arg]
        )
    # Initialize models
    logger.info("Loading models...")
    retriever = BGEM3Model(
        model_name_or_path=args.retriever_model,
        use_dense=args.use_dense,
        use_sparse=args.use_sparse,
        use_colbert=args.use_colbert,
        cache_dir=args.cache_dir
    )
    reranker = BGERerankerModel(
        model_name_or_path=args.reranker_model,
        cache_dir=args.cache_dir,
        use_fp16=args.fp16
    )
    from utils.config_loader import get_config
    logger.info("Retriever model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
    logger.info("Reranker model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface'))
    # Initialize optimizers — separate AdamW per model so the reranker can
    # use its own (higher) learning rate.
    from torch.optim import AdamW
    retriever_optimizer = AdamW(
        retriever.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )
    reranker_optimizer = AdamW(
        reranker.parameters(),
        lr=args.reranker_learning_rate,
        weight_decay=args.weight_decay
    )
    # Initialize joint trainer
    logger.info("Initializing joint trainer...")
    trainer = JointTrainer(
        config=joint_config,
        retriever=retriever,
        reranker=reranker,
        retriever_optimizer=retriever_optimizer,
        reranker_optimizer=reranker_optimizer,
        device=device
    )
    # Create data loaders: distributed samplers when launched with a rank,
    # plain DataLoaders otherwise. Eval loader is never shuffled.
    from torch.utils.data import DataLoader
    from utils.distributed import create_distributed_dataloader
    if args.local_rank != -1:
        train_dataloader = create_distributed_dataloader(
            joint_train_dataset,
            batch_size=args.per_device_train_batch_size,
            shuffle=True,
            num_workers=args.dataloader_num_workers,
            seed=args.seed
        )
        eval_dataloader = None
        if joint_eval_dataset:
            eval_dataloader = create_distributed_dataloader(
                joint_eval_dataset,
                batch_size=args.per_device_eval_batch_size,
                shuffle=False,
                num_workers=args.dataloader_num_workers,
                seed=args.seed
            )
    else:
        train_dataloader = DataLoader(
            joint_train_dataset,
            batch_size=args.per_device_train_batch_size,
            shuffle=True,
            num_workers=args.dataloader_num_workers
        )
        eval_dataloader = None
        if joint_eval_dataset:
            eval_dataloader = DataLoader(
                joint_eval_dataset,
                batch_size=args.per_device_eval_batch_size,
                num_workers=args.dataloader_num_workers
            )
    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)
    # Train
    logger.info("Starting joint training...")
    try:
        trainer.train(train_dataloader, eval_dataloader, args.num_train_epochs)
        # Save final models (rank 0 only)
        if is_main_process():
            logger.info("Saving final models...")
            final_dir = os.path.join(args.output_dir, "final_models")
            trainer.save_models(final_dir)
            # Save individual models alongside their tokenizers so each can
            # be loaded standalone with from_pretrained.
            retriever_path = os.path.join(args.output_dir, "final_retriever")
            reranker_path = os.path.join(args.output_dir, "final_reranker")
            os.makedirs(retriever_path, exist_ok=True)
            os.makedirs(reranker_path, exist_ok=True)
            retriever.save_pretrained(retriever_path)
            retriever_tokenizer.save_pretrained(retriever_path)
            reranker.model.save_pretrained(reranker_path)
            reranker_tokenizer.save_pretrained(reranker_path)
            # Save training summary (args, configs, final metrics) as JSON
            summary = {
                "training_args": vars(args),
                "retriever_config": retriever_config.to_dict(),
                "reranker_config": reranker_config.to_dict(),
                "joint_config": joint_config.to_dict(),
                "final_metrics": getattr(trainer, 'best_metric', None),
                "total_steps": getattr(trainer, 'global_step', 0),
                "training_completed": datetime.now().isoformat()
            }
            with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f:
                json.dump(summary, f, indent=2)
        logger.info(f"Joint training completed! Models saved to {args.output_dir}")
    except Exception as e:
        # Log with full traceback, then re-raise so the process exits non-zero.
        logger.error(f"Joint training failed with error: {e}", exc_info=True)
        raise
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()