import torch import torch.nn as nn from torch.utils.data import DataLoader from torch.optim import AdamW from typing import Dict, Optional, List, Tuple, Any import logging import numpy as np from torch.nn import functional as F from tqdm import tqdm import os from .trainer import BaseTrainer from models.bge_reranker import BGERerankerModel from models.losses import ListwiseCrossEntropyLoss, PairwiseMarginLoss from data.dataset import BGERerankerDataset from config.reranker_config import BGERerankerConfig from evaluation.metrics import compute_reranker_metrics from utils.distributed import is_main_process # AMP imports from torch.cuda.amp.autocast_mode import autocast from torch.cuda.amp.grad_scaler import GradScaler # Scheduler import try: from transformers import get_linear_schedule_with_warmup except ImportError: get_linear_schedule_with_warmup = None logger = logging.getLogger(__name__) class BGERerankerTrainer(BaseTrainer): """ Trainer for BGE-Reranker cross-encoder model Notes: - All parameters (config, model, tokenizer, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts. - This class does not load config directly; pass config values from your main script for consistency. """ def __init__( self, config: BGERerankerConfig, model: Optional[BGERerankerModel] = None, tokenizer=None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, device: Optional[torch.device] = None ): # Initialize model if not provided if model is None: model = BGERerankerModel( model_name_or_path=config.model_name_or_path, num_labels=config.num_labels, cache_dir=config.cache_dir, use_fp16=config.fp16, gradient_checkpointing=config.gradient_checkpointing ) # Initialize optimizer if not provided if optimizer is None: optimizer = self._create_optimizer(model, config) # Initialize parent class super().__init__(config, model, optimizer, scheduler, device) self.tokenizer = tokenizer self.config = config # Initialize losses based on config if config.loss_type == "listwise_ce": self.loss_fn = ListwiseCrossEntropyLoss( temperature=getattr(config, 'temperature', 1.0) ) elif config.loss_type == "pairwise_margin": self.loss_fn = PairwiseMarginLoss(margin=config.margin) elif config.loss_type == "combined": self.listwise_loss = ListwiseCrossEntropyLoss( temperature=getattr(config, 'temperature', 1.0) ) self.pairwise_loss = PairwiseMarginLoss(margin=config.margin) else: self.loss_fn = nn.BCEWithLogitsLoss() # Knowledge distillation setup self.use_kd = config.use_knowledge_distillation if self.use_kd: self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean') self.kd_temperature = config.distillation_temperature self.kd_loss_weight = config.distillation_alpha # Hard negative configuration self.hard_negative_ratio = config.hard_negative_ratio self.use_denoising = config.use_denoising self.denoise_threshold = getattr(config, 'denoise_confidence_threshold', 0.5) def _create_optimizer(self, model: nn.Module, config: BGERerankerConfig) -> torch.optim.Optimizer: """Create optimizer with proper parameter groups""" # Separate parameters no_decay = ['bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ { 'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': config.weight_decay, }, { 'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, } ] optimizer = AdamW( optimizer_grouped_parameters, lr=config.learning_rate, eps=config.adam_epsilon, betas=(config.adam_beta1, config.adam_beta2) ) return optimizer def create_scheduler(self, num_training_steps: int) -> Optional[object]: """Create learning rate scheduler""" if get_linear_schedule_with_warmup is None: logger.warning("transformers not available, using constant learning rate") return None num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio) self.scheduler = get_linear_schedule_with_warmup( self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps ) return self.scheduler def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """Compute reranker loss""" # Get model outputs outputs = self.model.forward( input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], return_dict=True ) if isinstance(outputs, dict): logits = outputs['logits'] else: logits = outputs.scores labels = batch['labels'] # Handle padding labels if labels.dim() > 1: # Remove padding (-1 labels) valid_mask = labels != -1 if valid_mask.any(): logits = logits[valid_mask] labels = labels[valid_mask] else: # If all labels are padding, return zero loss return torch.tensor(0.0, device=logits.device, requires_grad=True) # Handle different loss types if self.config.loss_type == "listwise_ce": # Listwise loss expects grouped format try: groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size) if groups: loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups) else: loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) except Exception as e: # Fallback to BCE with specific error logging logger.warning(f"Listwise CE loss failed: {e}. Falling back to BCE.") loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) elif self.config.loss_type == "pairwise_margin": # Extract positive and negative scores pos_mask = labels == 1 neg_mask = labels == 0 if pos_mask.any() and neg_mask.any(): groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size) try: loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups) except Exception as e: # Fallback to BCE with specific error logging logger.warning(f"Pairwise margin loss failed: {e}. Falling back to BCE.") loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) else: # Fallback to BCE if no proper pairs logger.warning("No positive-negative pairs found for pairwise loss. Using BCE.") loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) elif self.config.loss_type == "combined": # Combine listwise and pairwise losses try: groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size) if groups: listwise_loss = self.listwise_loss(logits.squeeze(-1), labels.float(), groups) pairwise_loss = self.pairwise_loss(logits.squeeze(-1), labels.float(), groups) loss = self.config.listwise_weight * listwise_loss + self.config.pairwise_weight * pairwise_loss else: loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) except Exception as e: # Fallback to BCE with specific error logging logger.warning(f"Combined loss computation failed: {e}. Falling back to BCE.") loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) else: # Default binary cross-entropy loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float()) # Add knowledge distillation loss if enabled if self.use_kd and 'teacher_scores' in batch: teacher_scores = batch['teacher_scores'] student_probs = F.log_softmax(logits / self.kd_temperature, dim=-1) teacher_probs = F.softmax(teacher_scores / self.kd_temperature, dim=-1) kd_loss = self.kd_loss_fn(student_probs, teacher_probs) * (self.kd_temperature ** 2) loss = (1 - self.kd_loss_weight) * loss + self.kd_loss_weight * kd_loss # Apply denoising if enabled if self.use_denoising and 'teacher_scores' in batch: # Denoise by confidence thresholding confidence_scores = torch.sigmoid(batch['teacher_scores']) confidence_mask = confidence_scores > self.denoise_threshold if confidence_mask.any(): loss = loss * confidence_mask.float().mean() return loss def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]: """Evaluate a single batch""" with torch.no_grad(): # Get model outputs outputs = self.model.forward( input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], return_dict=True ) if isinstance(outputs, dict): logits = outputs['logits'] else: logits = outputs.scores labels = batch['labels'] # Handle padding labels if labels.dim() > 1: valid_mask = labels != -1 if valid_mask.any(): logits = logits[valid_mask] labels = labels[valid_mask] else: return {'loss': 0.0, 'accuracy': 0.0} # Compute loss loss = self.compute_loss(batch) # Compute metrics based on group size if hasattr(self.config, 'train_group_size') and self.config.train_group_size > 1: # Listwise evaluation groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size) if groups: try: metrics = compute_reranker_metrics( logits.squeeze(-1).cpu().numpy(), labels.cpu().numpy(), groups=groups, k_values=[1, 3, 5, 10] ) metrics['loss'] = loss.item() except Exception as e: logger.warning(f"Error computing reranker metrics: {e}") metrics = {'loss': loss.item(), 'accuracy': 0.0} else: metrics = {'loss': loss.item(), 'accuracy': 0.0} else: # Binary evaluation predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).float() accuracy = (predictions == labels.float()).float().mean() metrics = { 'accuracy': accuracy.item(), 'loss': loss.item() } # Add AUC if we have both positive and negative examples if len(torch.unique(labels)) > 1: try: from sklearn.metrics import roc_auc_score auc = roc_auc_score(labels.cpu().numpy(), torch.sigmoid(logits.squeeze(-1)).cpu().numpy()) metrics['auc'] = float(auc) except Exception as e: logger.warning(f"Error computing AUC: {e}") return metrics def train_with_hard_negative_mining( self, train_dataset: BGERerankerDataset, eval_dataset: Optional[BGERerankerDataset] = None, retriever_model: Optional[nn.Module] = None ): """ Train with periodic hard negative mining from retriever """ if retriever_model is None: logger.warning("No retriever model provided for hard negative mining. Using standard training.") train_dataloader = self._create_dataloader(train_dataset, is_train=True) eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None # Only proceed if we have a valid training dataloader if train_dataloader is not None: return super().train(train_dataloader, eval_dataloader) else: logger.error("No valid training dataloader created") return # Training loop with hard negative updates for epoch in range(self.config.num_train_epochs): # Mine hard negatives at the beginning of each epoch if epoch > 0: logger.info(f"Mining hard negatives for epoch {epoch + 1}") self._mine_hard_negatives(train_dataset, retriever_model) # Create data loader with new negatives train_dataloader = self._create_dataloader(train_dataset, is_train=True) eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None # Only proceed if we have a valid training dataloader if train_dataloader is not None: # Train one epoch super().train(train_dataloader, eval_dataloader, num_epochs=1) else: logger.error(f"No valid training dataloader for epoch {epoch + 1}") continue # Update current epoch self.current_epoch = epoch def _create_dataloader(self, dataset, is_train: bool = True) -> Optional[DataLoader]: """Create data loader""" if dataset is None: return None from data.dataset import collate_reranker_batch batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size # Setup distributed sampler if needed sampler = None if self.is_distributed: from torch.utils.data.distributed import DistributedSampler sampler = DistributedSampler(dataset, shuffle=is_train) return DataLoader( dataset, batch_size=batch_size, shuffle=is_train and sampler is None, sampler=sampler, num_workers=self.config.dataloader_num_workers, pin_memory=self.config.dataloader_pin_memory, drop_last=is_train, collate_fn=collate_reranker_batch ) def _mine_hard_negatives( self, dataset: BGERerankerDataset, retriever_model: Any ): """Mine hard negatives using the retriever model and update the dataset's negatives.""" logger.info("Starting hard negative mining using retriever model...") num_hard_negatives = getattr(self.config, 'num_hard_negatives', 10) retriever_batch_size = getattr(self.config, 'retriever_batch_size', 32) mined_count = 0 # Collect all unique candidate passages from the dataset all_passages = set() for item in dataset.data: all_passages.update(item.get('pos', [])) all_passages.update(item.get('neg', [])) all_passages = list(all_passages) if not all_passages: logger.warning("No candidate passages found in dataset for hard negative mining.") return # Encode all candidate passages try: passage_embeddings = retriever_model.encode( all_passages, batch_size=retriever_batch_size, convert_to_numpy=True ) except Exception as e: logger.error(f"Error encoding passages for hard negative mining: {e}") return # For each query, mine hard negatives for idx in range(len(dataset)): item = dataset.data[idx] query = item['query'] positives = set(item.get('pos', [])) existing_negs = set(item.get('neg', [])) try: query_emb = retriever_model.encode([query], batch_size=1, convert_to_numpy=True)[0] # Compute similarity scores = np.dot(passage_embeddings, query_emb) # Rank passages by similarity ranked_indices = np.argsort(scores)[::-1] new_negs = [] for i in ranked_indices: passage = all_passages[int(i)] # Convert to int for indexing if passage not in positives and passage not in existing_negs: new_negs.append(passage) if len(new_negs) >= num_hard_negatives: break if new_negs: item['neg'] = list(existing_negs) + new_negs mined_count += 1 logger.debug(f"Query idx {idx}: Added {len(new_negs)} hard negatives.") except Exception as e: logger.error(f"Error mining hard negatives for query idx {idx}: {e}") logger.info(f"Hard negative mining complete. Updated negatives for {mined_count} queries.") def save_model(self, save_path: str): """Save the trained model""" try: os.makedirs(save_path, exist_ok=True) # Save model model_to_save = self.model.module if hasattr(self.model, 'module') else self.model if hasattr(model_to_save, 'save_pretrained'): model_to_save.save_pretrained(save_path) # type: ignore else: # Fallback to torch.save torch.save(model_to_save.state_dict(), os.path.join(save_path, 'pytorch_model.bin')) # type: ignore # Save tokenizer if available if self.tokenizer and hasattr(self.tokenizer, 'save_pretrained'): self.tokenizer.save_pretrained(save_path) # Save config if hasattr(self.config, 'save'): self.config.save(os.path.join(save_path, 'training_config.json')) logger.info(f"Model saved to {save_path}") except Exception as e: logger.error(f"Error saving model: {e}") raise def export_for_inference(self, output_path: str): """Export model for efficient inference""" if is_main_process(): logger.info(f"Exporting model for inference to {output_path}") # Save model in inference mode self.model.eval() # Save model state torch.save({ 'model_state_dict': self.model.state_dict(), 'config': self.config.to_dict(), 'model_type': 'bge_reranker' }, output_path) # Also save in HuggingFace format if possible try: hf_path = output_path.replace('.pt', '_hf') self.save_model(hf_path) except Exception as e: logger.warning(f"Failed to save in HuggingFace format: {e}") logger.info(f"Model exported successfully")