#!/usr/bin/env python3
"""
BGE-Reranker Cross-Encoder Model Training Script

Uses the refactored dataset pipeline with the flat pairs schema.
"""

import os
import sys
import argparse
import logging

import torch
import torch.nn.functional as F
from torch.optim import AdamW  # transformers' bundled AdamW is deprecated; use torch's
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

# Add parent directory to path so the project packages are importable
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.dataset import create_reranker_dataset, collate_reranker_batch, migrate_nested_to_flat
from models.bge_reranker import BGERerankerModel
from training.reranker_trainer import BGERerankerTrainer
from utils.logging import setup_logging, get_logger

logger = get_logger(__name__)


def parse_args():
    """Parse command line arguments for reranker training."""
    parser = argparse.ArgumentParser(description="Train BGE-Reranker Cross-Encoder Model")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="BAAI/bge-reranker-base",
                        help="Path to pretrained BGE-Reranker model")
    parser.add_argument("--cache_dir", type=str, default="./cache/data",
                        help="Directory to cache downloaded models")

    # Data arguments
    parser.add_argument("--reranker_data", type=str, required=True,
                        help="Path to reranker training data (reranker_data.jsonl)")
    parser.add_argument("--eval_data", type=str, default=None,
                        help="Path to evaluation data")
    parser.add_argument("--legacy_data", type=str, default=None,
                        help="Path to legacy nested data (will be migrated)")

    # Training arguments
    parser.add_argument("--output_dir", type=str, default="./output/bge-reranker",
                        help="Directory to save the trained model")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=8,
                        help="Batch size per GPU/CPU for training")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
                        help="Batch size per GPU/CPU for evaluation")
    parser.add_argument("--learning_rate", type=float, default=6e-5,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Warmup ratio")

    # Model-specific arguments
    parser.add_argument("--query_max_length", type=int, default=64,
                        help="Maximum query length")
    parser.add_argument("--passage_max_length", type=int, default=448,
                        help="Maximum passage length")
    parser.add_argument("--loss_type", type=str, default="cross_entropy",
                        choices=["cross_entropy", "binary_cross_entropy", "listwise_ce"],
                        help="Loss function type")

    # Enhanced features
    parser.add_argument("--use_score_weighting", action="store_true",
                        help="Use score weighting in loss calculation")
    parser.add_argument("--label_smoothing", type=float, default=0.0,
                        help="Label smoothing for cross entropy loss")

    # System arguments
    parser.add_argument("--fp16", action="store_true",
                        help="Use 16-bit floating point precision")
    parser.add_argument("--dataloader_num_workers", type=int, default=4,
                        help="Number of workers for data loading")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--log_level", type=str, default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        help="Logging level")

    # Backward compatibility
    parser.add_argument("--legacy_support", action="store_true", default=True,
                        help="Enable legacy nested format support")
    parser.add_argument("--migrate_legacy", action="store_true",
                        help="Migrate legacy nested format to flat format")

    return parser.parse_args()


def setup_model_and_tokenizer(args):
    """Setup BGE-Reranker model and tokenizer."""
    logger.info(f"Loading BGE-Reranker model from {args.model_name_or_path}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        trust_remote_code=True
    )

    # Load model
    model = BGERerankerModel(
        model_name_or_path=args.model_name_or_path,
        cache_dir=args.cache_dir
    )

    return model, tokenizer


def migrate_legacy_data(args):
    """Migrate legacy nested data to flat format if needed."""
    if args.legacy_data and args.migrate_legacy:
        logger.info("Migrating legacy nested data to flat format...")

        # Create output file name
        legacy_base = os.path.splitext(args.legacy_data)[0]
        migrated_file = f"{legacy_base}_migrated_flat.jsonl"

        # Migrate the data
        migrate_nested_to_flat(args.legacy_data, migrated_file)
        logger.info(f"✅ Migrated legacy data to {migrated_file}")

        return migrated_file

    return None


def setup_datasets(args, tokenizer):
    """Setup reranker datasets."""
    logger.info("Setting up reranker datasets...")

    # Migrate legacy data if needed
    migrated_file = migrate_legacy_data(args)

    # Primary reranker data
    train_dataset = create_reranker_dataset(
        data_path=args.reranker_data,
        tokenizer=tokenizer,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
        is_train=True,
        legacy_support=args.legacy_support
    )

    logger.info(f"Loaded {len(train_dataset)} training examples from {args.reranker_data}")
    logger.info(f"Detected format: {train_dataset.format_type}")

    # Add migrated legacy data if available
    if migrated_file:
        logger.info(f"Loading migrated legacy data from {migrated_file}")
        migrated_dataset = create_reranker_dataset(
            data_path=migrated_file,
            tokenizer=tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            is_train=True,
            legacy_support=False  # Migrated data is in flat format
        )
        logger.info(f"Loaded {len(migrated_dataset)} migrated examples")

        # Combine datasets
        train_dataset.data.extend(migrated_dataset.data)
        logger.info(f"Total training examples: {len(train_dataset)}")

    # Evaluation dataset
    eval_dataset = None
    if args.eval_data:
        eval_dataset = create_reranker_dataset(
            data_path=args.eval_data,
            tokenizer=tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            is_train=False,
            legacy_support=args.legacy_support
        )
        logger.info(f"Loaded {len(eval_dataset)} evaluation examples")

    return train_dataset, eval_dataset


def setup_loss_function(args):
    """Setup cross-encoder loss function."""
    if args.loss_type == "cross_entropy":
        if args.label_smoothing > 0:
            loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
        else:
            loss_fn = torch.nn.CrossEntropyLoss()
    elif args.loss_type == "binary_cross_entropy":
        loss_fn = torch.nn.BCEWithLogitsLoss()
    elif args.loss_type == "listwise_ce":
        # Custom listwise cross-entropy loss
        loss_fn = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError(f"Unknown loss type: {args.loss_type}")

    logger.info(f"Using loss function: {args.loss_type}")
    return loss_fn


def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, device, args):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    num_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs

        # Calculate loss
        if args.loss_type == "binary_cross_entropy":
            # For binary classification (squeeze only the score dimension)
            loss = loss_fn(logits.squeeze(-1), labels.float())
        else:
            # For multi-class classification
            loss = loss_fn(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        num_batches += 1

        # Log progress
        if batch_idx % 100 == 0:
            logger.info(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")

    return total_loss / num_batches


def evaluate(model, dataloader, loss_fn, device, args):
    """Evaluate the model."""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs

            # Calculate loss
            if args.loss_type == "binary_cross_entropy":
                loss = loss_fn(logits.squeeze(-1), labels.float())
                # Calculate accuracy for binary classification
                predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).long()
            else:
                loss = loss_fn(logits, labels)
                # Calculate accuracy for multi-class classification
                predictions = torch.argmax(logits, dim=-1)

            total_loss += loss.item()
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total if total > 0 else 0
    return total_loss / len(dataloader), accuracy


def main():
    """Main training function."""
    args = parse_args()

    # Setup logging
    setup_logging(
        output_dir=args.output_dir,
        log_level=args.log_level,
        log_to_file=True,
        log_to_console=True
    )

    logger.info("🚀 Starting BGE-Reranker Training")
    logger.info(f"Arguments: {vars(args)}")

    # Set random seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Setup device (note: --fp16 requires CUDA; mixed precision is not applied in this loop)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args)
    model.to(device)

    # Setup datasets
    train_dataset, eval_dataset = setup_datasets(args, tokenizer)

    # Setup data loaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.per_device_train_batch_size,
        shuffle=True,
        num_workers=args.dataloader_num_workers,
        collate_fn=collate_reranker_batch,
        pin_memory=torch.cuda.is_available()
    )

    eval_dataloader = None
    if eval_dataset:
        eval_dataloader = DataLoader(
            eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            shuffle=False,
            num_workers=args.dataloader_num_workers,
            collate_fn=collate_reranker_batch,
            pin_memory=torch.cuda.is_available()
        )

    # Setup optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    total_steps = len(train_dataloader) * args.num_train_epochs
    warmup_steps = int(total_steps * args.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Setup loss function
    loss_fn = setup_loss_function(args)

    # Training loop
    logger.info("🎯 Starting training...")
    best_eval_loss = float('inf')

    for epoch in range(args.num_train_epochs):
        logger.info(f"Epoch {epoch + 1}/{args.num_train_epochs}")

        # Train
        train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device, args)
        logger.info(f"Training loss: {train_loss:.4f}")
        # Evaluate
        if eval_dataloader:
            eval_loss, eval_accuracy = evaluate(model, eval_dataloader, loss_fn, device, args)
            logger.info(f"Evaluation loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")

            # Save best model
            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                best_model_path = os.path.join(args.output_dir, "best_model")
                os.makedirs(best_model_path, exist_ok=True)
                model.save_pretrained(best_model_path)
                tokenizer.save_pretrained(best_model_path)
                logger.info(f"Saved best model to {best_model_path}")

    # Save final model
    final_model_path = os.path.join(args.output_dir, "final_model")
    os.makedirs(final_model_path, exist_ok=True)
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    logger.info(f"✅ Training completed! Final model saved to {final_model_path}")

    # Show training summary
    logger.info("📊 Training Summary:")
    logger.info(f"  Total examples: {len(train_dataset)}")
    logger.info(f"  Format detected: {train_dataset.format_type}")
    logger.info(f"  Epochs: {args.num_train_epochs}")
    logger.info(f"  Learning rate: {args.learning_rate}")
    logger.info(f"  Batch size: {args.per_device_train_batch_size}")
    logger.info(f"  Loss function: {args.loss_type}")


if __name__ == "__main__":
    main()
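
# Example invocation (illustrative sketch only). The script filename and the data/output
# paths below are placeholders, not files shipped with the repo; every flag used here is
# defined in parse_args() above.
#
#   python train_bge_reranker.py \
#       --reranker_data ./data/reranker_data.jsonl \
#       --eval_data ./data/reranker_eval.jsonl \
#       --output_dir ./output/bge-reranker \
#       --loss_type binary_cross_entropy \
#       --per_device_train_batch_size 8 \
#       --num_train_epochs 3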