""" Joint training of BGE-M3 and BGE-Reranker using RocketQAv2 approach """ import os import sys import argparse # Set tokenizer parallelism to avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import logging import json from datetime import datetime import torch from transformers import AutoTokenizer, set_seed # Add parent directory to path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from config.m3_config import BGEM3Config from config.reranker_config import BGERerankerConfig from models.bge_m3 import BGEM3Model from models.bge_reranker import BGERerankerModel from data.dataset import BGEM3Dataset, BGERerankerDataset, JointDataset from data.preprocessing import DataPreprocessor from training.joint_trainer import JointTrainer from utils.logging import setup_logging, get_logger from utils.distributed import setup_distributed, is_main_process, set_random_seed logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser(description="Joint training of BGE-M3 and BGE-Reranker") # Model arguments parser.add_argument("--retriever_model", type=str, default="BAAI/bge-m3", help="Path to retriever model") parser.add_argument("--reranker_model", type=str, default="BAAI/bge-reranker-base", help="Path to reranker model") parser.add_argument("--cache_dir", type=str, default="./cache/data", help="Directory to cache downloaded models") # Data arguments parser.add_argument("--train_data", type=str, required=True, help="Path to training data file") parser.add_argument("--eval_data", type=str, default=None, help="Path to evaluation data file") parser.add_argument("--data_format", type=str, default="auto", choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"], help="Format of input data") # Training arguments parser.add_argument("--output_dir", type=str, default="./output/joint-training", help="Directory to save the models") parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of training epochs") parser.add_argument("--per_device_train_batch_size", type=int, default=4, help="Batch size per GPU/CPU for training") parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Batch size per GPU/CPU for evaluation") parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="Number of updates steps to accumulate") parser.add_argument("--learning_rate", type=float, default=1e-5, help="Initial learning rate for retriever") parser.add_argument("--reranker_learning_rate", type=float, default=6e-5, help="Initial learning rate for reranker") parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay") # Joint training specific parser.add_argument("--joint_loss_weight", type=float, default=0.3, help="Weight for joint distillation loss") parser.add_argument("--update_retriever_steps", type=int, default=100, help="Update retriever every N steps") parser.add_argument("--use_dynamic_distillation", action="store_true", help="Use dynamic listwise distillation") # Model configuration parser.add_argument("--retriever_train_group_size", type=int, default=8, help="Number of passages per query for retriever") parser.add_argument("--reranker_train_group_size", type=int, default=16, help="Number of passages per query for reranker") parser.add_argument("--temperature", type=float, default=0.02, help="Temperature for InfoNCE loss") # BGE-M3 specific parser.add_argument("--use_dense", action="store_true", default=True, help="Use dense embeddings") parser.add_argument("--use_sparse", action="store_true", help="Use sparse embeddings") parser.add_argument("--use_colbert", action="store_true", help="Use ColBERT embeddings") parser.add_argument("--query_max_length", type=int, default=512, help="Maximum query length") parser.add_argument("--passage_max_length", type=int, default=8192, help="Maximum passage length") # Reranker specific parser.add_argument("--reranker_max_length", type=int, default=512, help="Maximum sequence length for reranker") # Performance parser.add_argument("--fp16", action="store_true", help="Use mixed precision training") parser.add_argument("--gradient_checkpointing", action="store_true", help="Use gradient checkpointing") parser.add_argument("--dataloader_num_workers", type=int, default=4, help="Number of workers for data loading") # Checkpointing parser.add_argument("--save_steps", type=int, default=1000, help="Save checkpoint every N steps") parser.add_argument("--eval_steps", type=int, default=1000, help="Evaluate every N steps") parser.add_argument("--logging_steps", type=int, default=100, help="Log every N steps") parser.add_argument("--save_total_limit", type=int, default=3, help="Maximum number of checkpoints to keep") parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint to resume from") # Distributed training parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training") # Other parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--overwrite_output_dir", action="store_true", help="Overwrite the output directory") args = parser.parse_args() # Validate arguments if os.path.exists(args.output_dir) and not args.overwrite_output_dir: raise ValueError(f"Output directory {args.output_dir} already exists. Use --overwrite_output_dir to overwrite.") return args def main(): # Parse arguments args = parse_args() # Setup logging setup_logging( output_dir=args.output_dir, log_level=logging.INFO, log_to_console=True, log_to_file=True ) logger = get_logger(__name__) logger.info("Starting joint training of BGE-M3 and BGE-Reranker") logger.info(f"Arguments: {args}") # Set random seed set_random_seed(args.seed, deterministic=False) # Setup distributed training if needed device = None if args.local_rank != -1: setup_distributed() from utils.config_loader import get_config config = get_config() device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined] if device_type == 'npu': try: import torch_npu torch_npu.npu.set_device(args.local_rank) device = torch.device("npu", args.local_rank) logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=NPU, world_size={os.environ.get('WORLD_SIZE', 'N/A')}") except ImportError: logger.warning("torch_npu not installed, falling back to CPU") device = torch.device("cpu") elif torch.cuda.is_available(): torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) logger.info(f"Distributed training enabled: local_rank={args.local_rank}, device=CUDA, world_size={os.environ.get('WORLD_SIZE', 'N/A')}") else: device = torch.device("cpu") logger.info("Distributed training enabled: using CPU") else: from utils.config_loader import get_config config = get_config() device_type = config.get('ascend.device_type', None) # type: ignore[attr-defined] if device_type == 'npu': try: import torch_npu device = torch.device("npu") logger.info("Using Ascend NPU") except ImportError: logger.warning("torch_npu not installed, falling back to CPU") device = torch.device("cpu") elif torch.cuda.is_available(): device = torch.device("cuda") logger.info(f"Using CUDA GPU (device count: {torch.cuda.device_count()})") else: device = torch.device("cpu") logger.info("Using CPU") # Create configurations retriever_config = BGEM3Config( model_name_or_path=args.retriever_model, cache_dir=args.cache_dir, train_data_path=args.train_data, eval_data_path=args.eval_data, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, output_dir=os.path.join(args.output_dir, "retriever"), num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, warmup_ratio=args.warmup_ratio, weight_decay=args.weight_decay, train_group_size=args.retriever_train_group_size, temperature=args.temperature, use_dense=args.use_dense, use_sparse=args.use_sparse, use_colbert=args.use_colbert, fp16=args.fp16, gradient_checkpointing=args.gradient_checkpointing, dataloader_num_workers=args.dataloader_num_workers, save_steps=args.save_steps, eval_steps=args.eval_steps, logging_steps=args.logging_steps, save_total_limit=args.save_total_limit, seed=args.seed, local_rank=args.local_rank, device=str(device) ) reranker_config = BGERerankerConfig( model_name_or_path=args.reranker_model, cache_dir=args.cache_dir, train_data_path=args.train_data, eval_data_path=args.eval_data, max_seq_length=args.reranker_max_length, output_dir=os.path.join(args.output_dir, "reranker"), num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.reranker_learning_rate, warmup_ratio=args.warmup_ratio, weight_decay=args.weight_decay, train_group_size=args.reranker_train_group_size, joint_training=True, # type: ignore[call-arg] joint_loss_weight=args.joint_loss_weight, # type: ignore[call-arg] use_dynamic_listwise_distillation=args.use_dynamic_distillation, # type: ignore[call-arg] update_retriever_steps=args.update_retriever_steps, # type: ignore[call-arg] fp16=args.fp16, gradient_checkpointing=args.gradient_checkpointing, dataloader_num_workers=args.dataloader_num_workers, save_steps=args.save_steps, eval_steps=args.eval_steps, logging_steps=args.logging_steps, save_total_limit=args.save_total_limit, seed=args.seed, local_rank=args.local_rank, device=str(device) ) # Create joint config (use retriever config as base) joint_config = retriever_config joint_config.output_dir = args.output_dir joint_config.joint_loss_weight = args.joint_loss_weight # type: ignore[attr-defined] joint_config.use_dynamic_listwise_distillation = args.use_dynamic_distillation # type: ignore[attr-defined] joint_config.update_retriever_steps = args.update_retriever_steps # type: ignore[attr-defined] # Save configurations if is_main_process(): os.makedirs(args.output_dir, exist_ok=True) joint_config.save(os.path.join(args.output_dir, "joint_config.json")) retriever_config.save(os.path.join(args.output_dir, "retriever_config.json")) reranker_config.save(os.path.join(args.output_dir, "reranker_config.json")) # Load tokenizers logger.info("Loading tokenizers...") retriever_tokenizer = AutoTokenizer.from_pretrained( args.retriever_model, cache_dir=args.cache_dir ) reranker_tokenizer = AutoTokenizer.from_pretrained( args.reranker_model, cache_dir=args.cache_dir ) # Prepare data logger.info("Preparing data...") preprocessor = DataPreprocessor( tokenizer=retriever_tokenizer, # Use retriever tokenizer for preprocessing max_query_length=args.query_max_length, max_passage_length=args.passage_max_length ) # Preprocess training data train_path, val_path = preprocessor.preprocess_file( input_path=args.train_data, output_path=os.path.join(args.output_dir, "processed_train.jsonl"), file_format=args.data_format, validation_split=0.1 if args.eval_data is None else 0.0 ) # Use provided eval data or split validation eval_path = args.eval_data if args.eval_data else val_path # Create datasets logger.info("Creating datasets...") # Retriever dataset retriever_train_dataset = BGEM3Dataset( data_path=train_path, tokenizer=retriever_tokenizer, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, train_group_size=args.retriever_train_group_size, is_train=True ) # Reranker dataset reranker_train_dataset = BGERerankerDataset( data_path=train_path, tokenizer=reranker_tokenizer, max_seq_length=args.reranker_max_length, # type: ignore[call-arg] train_group_size=args.reranker_train_group_size, is_train=True ) # Joint dataset joint_train_dataset = JointDataset( # type: ignore[call-arg] retriever_dataset=retriever_train_dataset, # type: ignore[call-arg] reranker_dataset=reranker_train_dataset # type: ignore[call-arg] ) # Evaluation datasets joint_eval_dataset = None if eval_path: retriever_eval_dataset = BGEM3Dataset( data_path=eval_path, tokenizer=retriever_tokenizer, query_max_length=args.query_max_length, passage_max_length=args.passage_max_length, train_group_size=args.retriever_train_group_size, is_train=False ) reranker_eval_dataset = BGERerankerDataset( data_path=eval_path, tokenizer=reranker_tokenizer, max_seq_length=args.reranker_max_length, # type: ignore[call-arg] train_group_size=args.reranker_train_group_size, is_train=False ) joint_eval_dataset = JointDataset( # type: ignore[call-arg] retriever_dataset=retriever_eval_dataset, # type: ignore[call-arg] reranker_dataset=reranker_eval_dataset # type: ignore[call-arg] ) # Initialize models logger.info("Loading models...") retriever = BGEM3Model( model_name_or_path=args.retriever_model, use_dense=args.use_dense, use_sparse=args.use_sparse, use_colbert=args.use_colbert, cache_dir=args.cache_dir ) reranker = BGERerankerModel( model_name_or_path=args.reranker_model, cache_dir=args.cache_dir, use_fp16=args.fp16 ) from utils.config_loader import get_config logger.info("Retriever model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface')) logger.info("Reranker model loaded successfully (source: %s)", get_config().get('model_paths.source', 'huggingface')) # Initialize optimizers from torch.optim import AdamW retriever_optimizer = AdamW( retriever.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay ) reranker_optimizer = AdamW( reranker.parameters(), lr=args.reranker_learning_rate, weight_decay=args.weight_decay ) # Initialize joint trainer logger.info("Initializing joint trainer...") trainer = JointTrainer( config=joint_config, retriever=retriever, reranker=reranker, retriever_optimizer=retriever_optimizer, reranker_optimizer=reranker_optimizer, device=device ) # Create data loaders from torch.utils.data import DataLoader from utils.distributed import create_distributed_dataloader if args.local_rank != -1: train_dataloader = create_distributed_dataloader( joint_train_dataset, batch_size=args.per_device_train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers, seed=args.seed ) eval_dataloader = None if joint_eval_dataset: eval_dataloader = create_distributed_dataloader( joint_eval_dataset, batch_size=args.per_device_eval_batch_size, shuffle=False, num_workers=args.dataloader_num_workers, seed=args.seed ) else: train_dataloader = DataLoader( joint_train_dataset, batch_size=args.per_device_train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) eval_dataloader = None if joint_eval_dataset: eval_dataloader = DataLoader( joint_eval_dataset, batch_size=args.per_device_eval_batch_size, num_workers=args.dataloader_num_workers ) # Resume from checkpoint if specified if args.resume_from_checkpoint: trainer.load_checkpoint(args.resume_from_checkpoint) # Train logger.info("Starting joint training...") try: trainer.train(train_dataloader, eval_dataloader, args.num_train_epochs) # Save final models if is_main_process(): logger.info("Saving final models...") final_dir = os.path.join(args.output_dir, "final_models") trainer.save_models(final_dir) # Save individual models retriever_path = os.path.join(args.output_dir, "final_retriever") reranker_path = os.path.join(args.output_dir, "final_reranker") os.makedirs(retriever_path, exist_ok=True) os.makedirs(reranker_path, exist_ok=True) retriever.save_pretrained(retriever_path) retriever_tokenizer.save_pretrained(retriever_path) reranker.model.save_pretrained(reranker_path) reranker_tokenizer.save_pretrained(reranker_path) # Save training summary summary = { "training_args": vars(args), "retriever_config": retriever_config.to_dict(), "reranker_config": reranker_config.to_dict(), "joint_config": joint_config.to_dict(), "final_metrics": getattr(trainer, 'best_metric', None), "total_steps": getattr(trainer, 'global_step', 0), "training_completed": datetime.now().isoformat() } with open(os.path.join(args.output_dir, "training_summary.json"), 'w') as f: json.dump(summary, f, indent=2) logger.info(f"Joint training completed! Models saved to {args.output_dir}") except Exception as e: logger.error(f"Joint training failed with error: {e}", exc_info=True) raise if __name__ == "__main__": main()