import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, List, Tuple
import logging
from tqdm import tqdm
import numpy as np
import os

from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import (
    InfoNCELoss,
    ListwiseCrossEntropyLoss,
    DynamicListwiseDistillationLoss,
    MultiTaskLoss
)
from data.dataset import JointDataset
from config.base_config import BaseConfig
from utils.distributed import is_main_process, get_world_size, reduce_dict
from utils.logging import get_logger

# Import autocast and GradScaler, preferring the newer torch.amp namespace
# with a fallback to torch.cuda.amp for older PyTorch versions
try:
    from torch.amp import autocast, GradScaler
except ImportError:
    from torch.cuda.amp import autocast, GradScaler

# Scheduler import
try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None

logger = get_logger(__name__)


class JointTrainer(BaseTrainer):
    """
    Joint trainer for BGE-M3 and BGE-Reranker using the RocketQAv2 approach.

    Notes:
        - All parameters (config, retriever, reranker, optimizers, device, etc.)
          should be passed in from config-driven training scripts.
        - This class does not load config directly; pass config values from your
          main script for consistency.
    """

    def __init__(
        self,
        config: BaseConfig,
        retriever: BGEM3Model,
        reranker: BGERerankerModel,
        retriever_optimizer: Optional[torch.optim.Optimizer] = None,
        reranker_optimizer: Optional[torch.optim.Optimizer] = None,
        device: Optional[torch.device] = None
    ):
        # Create a default optimizer if none is provided
        if retriever_optimizer is None:
            retriever_optimizer = AdamW(
                retriever.parameters(),
                lr=config.learning_rate,
                weight_decay=config.weight_decay
            )

        # Use the retriever as the main model for the base class
        super().__init__(config, retriever, retriever_optimizer, device=device)
        self.retriever = self.model  # Alias for clarity

        self.reranker = reranker
        self.reranker.to(self.device)

        # Set up the reranker optimizer if not provided
        if reranker_optimizer is None:
            self.reranker_optimizer = AdamW(
                self.reranker.parameters(),
                lr=config.learning_rate * 2,  # Higher LR for the reranker
                weight_decay=config.weight_decay
            )
        else:
            self.reranker_optimizer = reranker_optimizer

        # Initialize losses
        self.retriever_loss_fn = InfoNCELoss(
            temperature=getattr(config, 'temperature', 0.02),
            use_in_batch_negatives=True  # type: ignore[call-arg]
        )
        self.reranker_loss_fn = ListwiseCrossEntropyLoss(
            temperature=1.0,
            label_smoothing=0.0  # type: ignore[call-arg]
        )
        self.distillation_loss_fn = DynamicListwiseDistillationLoss(
            temperature=1.0,
            alpha=0.5  # type: ignore[call-arg]
        )

        # Joint-training-specific parameters
        self.update_retriever_interval = getattr(config, 'update_retriever_steps', 100)
        self.joint_loss_weight = getattr(config, 'joint_loss_weight', 0.3)
        self.use_dynamic_distillation = getattr(config, 'use_dynamic_listwise_distillation', True)

        # Tracking metrics
        self.retriever_metrics = []
        self.reranker_metrics = []

    def compute_loss(self, batch: Tuple[Dict, Dict]) -> torch.Tensor:
        """
        Compute the joint loss for both retriever and reranker.

        Args:
            batch: Tuple of (retriever_batch, reranker_batch)
        """
        retriever_batch, reranker_batch = batch

        # 1. Compute retriever loss
        retriever_loss = self._compute_retriever_loss(retriever_batch)

        # 2. Compute reranker loss
        reranker_loss = self._compute_reranker_loss(reranker_batch)

        # 3. Compute distillation losses if enabled
        if self.use_dynamic_distillation:
            # Get scores from both models WITH gradients enabled (wrapping this
            # in torch.no_grad() would make the distillation term a constant
            # with zero gradient); the distillation loss is expected to detach
            # the "teacher" side internally so each model only learns from the
            # other's frozen distribution.
            retriever_scores = self._get_retriever_scores(retriever_batch)
            reranker_scores = self._get_reranker_scores(reranker_batch)

            # Compute mutual distillation
            retriever_distill_loss, reranker_distill_loss = self.distillation_loss_fn(
                retriever_scores,
                reranker_scores,
                labels=reranker_batch.get('labels')
            )

            # Add the distillation terms to the main losses
            retriever_loss = retriever_loss + self.joint_loss_weight * retriever_distill_loss
            reranker_loss = reranker_loss + self.joint_loss_weight * reranker_distill_loss

        # Stash the individual losses; separate backward passes happen in _train_epoch
        self.current_retriever_loss = retriever_loss
        self.current_reranker_loss = reranker_loss

        # Return the combined loss for logging
        return retriever_loss + reranker_loss
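    # A minimal sketch of the mutual distillation objective, assuming the
    # RocketQAv2-style formulation; the actual implementation lives in
    # models.losses.DynamicListwiseDistillationLoss and may differ:
    #
    #     p_ret  = softmax(retriever_scores / T, dim=-1)
    #     p_rank = softmax(reranker_scores / T, dim=-1)
    #     retriever_distill_loss = KL(p_rank.detach() || p_ret)
    #     reranker_distill_loss  = KL(p_ret.detach() || p_rank)
    #
    # i.e. each model is pulled toward the other's (detached) listwise
    # distribution over the candidate passages.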
    def _compute_retriever_loss(self, batch: Dict) -> torch.Tensor:
        """Compute the retriever (bi-encoder) loss"""
        # Get embeddings
        query_outputs = self.retriever(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )
        passage_outputs = self.retriever(
            input_ids=batch['passage_input_ids'],
            attention_mask=batch['passage_attention_mask'],
            return_dense=True
        )

        # Compute contrastive loss
        loss = self.retriever_loss_fn(
            query_outputs['dense'],
            passage_outputs['dense'],
            labels=batch.get('labels')
        )
        return loss

    def _compute_reranker_loss(self, batch: Dict) -> torch.Tensor:
        """Compute the reranker (cross-encoder) loss"""
        outputs = self.reranker(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        logits = outputs['logits']
        labels = batch['labels']

        # Choose the loss based on the shape of the logits
        if logits.dim() == 1:
            # Binary classification
            loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
        else:
            # Listwise ranking
            loss = self.reranker_loss_fn(logits, labels)
        return loss

    def _get_retriever_scores(self, batch: Dict) -> torch.Tensor:
        """
        Get retriever similarity scores.

        Note: no torch.no_grad() here -- gradients must flow through these
        scores for the distillation term during training; evaluation callers
        wrap this call in torch.no_grad() themselves.
        """
        query_outputs = self.retriever(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )
        passage_outputs = self.retriever(
            input_ids=batch['passage_input_ids'],
            attention_mask=batch['passage_attention_mask'],
            return_dense=True
        )

        # Compute query-passage similarities
        scores = torch.matmul(
            query_outputs['dense'],
            passage_outputs['dense'].transpose(-1, -2)
        )
        return scores

    def _get_reranker_scores(self, batch: Dict) -> torch.Tensor:
        """Get reranker scores"""
        outputs = self.reranker(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        return outputs['logits']
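    # Expected batch layout, inferred from the methods above (the collate
    # function in data.dataset is the source of truth):
    #
    #     retriever_batch = {
    #         'query_input_ids': ...,   'query_attention_mask': ...,
    #         'passage_input_ids': ..., 'passage_attention_mask': ...,
    #         'labels': ...,            # optional
    #     }
    #     reranker_batch = {
    #         'input_ids': ..., 'attention_mask': ..., 'labels': ...,
    #     }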
    def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train both models for one epoch"""
        self.retriever.train()
        self.reranker.train()

        epoch_metrics = {
            'retriever_loss': 0.0,
            'reranker_loss': 0.0,
            'total_loss': 0.0
        }
        num_batches = 0

        # Set the epoch for a distributed sampler
        if hasattr(dataloader, 'sampler') and hasattr(dataloader.sampler, 'set_epoch'):
            dataloader.sampler.set_epoch(epoch)  # type: ignore[attr-defined]

        progress_bar = tqdm(
            dataloader,
            desc=f"Joint Training Epoch {epoch + 1}",
            disable=not is_main_process()
        )

        for step, batch in enumerate(progress_bar):
            # Prepare batch
            retriever_batch, reranker_batch = batch
            retriever_batch = self._prepare_batch(retriever_batch)
            reranker_batch = self._prepare_batch(reranker_batch)
            batch = (retriever_batch, reranker_batch)

            # Forward pass
            total_loss = self.compute_loss(batch)

            # Scale losses for gradient accumulation
            retriever_loss = self.current_retriever_loss / self.config.gradient_accumulation_steps
            reranker_loss = self.current_reranker_loss / self.config.gradient_accumulation_steps

            # Backward passes (separate for each model); retain_graph is kept
            # defensively in case the distillation term couples the two graphs
            if self.scaler:
                self.scaler.scale(retriever_loss).backward(retain_graph=True)  # type: ignore[attr-defined]
                self.scaler.scale(reranker_loss).backward()  # type: ignore[attr-defined]
            else:
                retriever_loss.backward(retain_graph=True)
                reranker_loss.backward()

            # Update weights
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradient clipping
                if self.scaler:
                    self.scaler.unscale_(self.optimizer)
                    self.scaler.unscale_(self.reranker_optimizer)
                torch.nn.utils.clip_grad_norm_(  # type: ignore[attr-defined]
                    self.retriever.parameters(),
                    self.config.max_grad_norm
                )
                torch.nn.utils.clip_grad_norm_(  # type: ignore[attr-defined]
                    self.reranker.parameters(),
                    self.config.max_grad_norm
                )

                # Optimizer steps
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.step(self.reranker_optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                    self.reranker_optimizer.step()

                # Scheduler steps
                if self.scheduler:
                    self.scheduler.step()
                    # Note: add a reranker scheduler here if needed

                # Zero gradients
                self.optimizer.zero_grad()
                self.reranker_optimizer.zero_grad()

                self.global_step += 1

                # Update hard negatives periodically
                if self.global_step % self.update_retriever_interval == 0:
                    self._update_hard_negatives()

                # Logging
                if self.global_step % self.config.logging_steps == 0:
                    step_metrics = {
                        'retriever_loss': retriever_loss.item() * self.config.gradient_accumulation_steps,
                        'reranker_loss': reranker_loss.item() * self.config.gradient_accumulation_steps,
                        'total_loss': (retriever_loss.item() + reranker_loss.item()) * self.config.gradient_accumulation_steps,
                        'retriever_lr': self.optimizer.param_groups[0]['lr'],
                        'reranker_lr': self.reranker_optimizer.param_groups[0]['lr']
                    }
                    if self.is_distributed:
                        step_metrics = reduce_dict(step_metrics)

                    progress_bar.set_postfix({
                        'ret_loss': f"{step_metrics['retriever_loss']:.4f}",
                        'rank_loss': f"{step_metrics['reranker_loss']:.4f}"
                    })

                    if is_main_process():
                        for key, value in step_metrics.items():
                            self.tb_logger.log_scalar(f"joint/{key}", value, self.global_step)
                        # Flush logs periodically
                        if self.global_step % (self.config.logging_steps * 10) == 0:
                            self.tb_logger.flush()

            # Accumulate metrics (losses here are already divided by the
            # accumulation factor; it is multiplied back when averaging below)
            epoch_metrics['retriever_loss'] += retriever_loss.item()
            epoch_metrics['reranker_loss'] += reranker_loss.item()
            epoch_metrics['total_loss'] += (retriever_loss.item() + reranker_loss.item())
            num_batches += 1

        # Average metrics, undoing the gradient-accumulation scaling
        for key in epoch_metrics:
            epoch_metrics[key] = epoch_metrics[key] / num_batches * self.config.gradient_accumulation_steps

        if self.is_distributed:
            epoch_metrics = reduce_dict(epoch_metrics)

        return epoch_metrics

    def evaluate_step(self, batch: Tuple[Dict, Dict]) -> Dict[str, float]:
        """Evaluate both models on a batch"""
        retriever_batch, reranker_batch = batch
        metrics = {}

        # Evaluate the retriever
        with torch.no_grad():
            retriever_loss = self._compute_retriever_loss(retriever_batch)
            metrics['retriever_loss'] = retriever_loss.item()

            # Compute retrieval metrics
            retriever_scores = self._get_retriever_scores(retriever_batch)
            labels = retriever_batch['labels']

            # Simple accuracy: is the positive passage ranked first?
            top_indices = retriever_scores.argmax(dim=-1)
            correct = (top_indices == labels.argmax(dim=-1)).float().mean()
            metrics['retriever_accuracy'] = correct.item()

        # Evaluate the reranker
        with torch.no_grad():
            reranker_loss = self._compute_reranker_loss(reranker_batch)
            metrics['reranker_loss'] = reranker_loss.item()

            # Compute reranking metrics
            reranker_scores = self._get_reranker_scores(reranker_batch)
            labels = reranker_batch['labels']

            if reranker_scores.dim() == 1:
                # Binary accuracy
                predictions = (reranker_scores > 0).float()
                correct = (predictions == labels).float().mean()
            else:
                # Ranking accuracy
                top_indices = reranker_scores.argmax(dim=-1)
                correct = (top_indices == labels.argmax(dim=-1)).float().mean()
            metrics['reranker_accuracy'] = correct.item()

        metrics['total_loss'] = metrics['retriever_loss'] + metrics['reranker_loss']
        return metrics

    def _update_hard_negatives(self):
        """Update hard negatives using the current models (simplified version)"""
        # This would implement dynamic hard negative mining.
        # For now, just log that an update would happen (see the sketch below).
        if is_main_process():
            logger.info(f"Would update hard negatives at step {self.global_step}")
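    # A sketch of what a full _update_hard_negatives implementation could look
    # like, assuming access to a corpus dataloader and a dataset with a
    # replace_negatives() hook (both hypothetical -- neither exists in this
    # repo as written):
    #
    #     1. Encode the corpus with the current retriever under torch.no_grad().
    #     2. For each training query, retrieve the top-k passages by dense
    #        similarity.
    #     3. Drop known positives, and optionally drop passages the reranker
    #        scores highly (denoising, as in RocketQAv2).
    #     4. Write the remaining passages back as the query's hard negatives,
    #        e.g. dataset.replace_negatives(query_id, mined_ids).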
    def save_models(self, output_dir: str):
        """Save both models"""
        if not is_main_process():
            return

        # Save the retriever, unwrapping DDP if necessary
        retriever_path = os.path.join(output_dir, "retriever")
        os.makedirs(retriever_path, exist_ok=True)
        if hasattr(self.retriever, 'module'):
            self.retriever.module.save_pretrained(retriever_path)  # type: ignore[attr-defined]
        else:
            self.retriever.save_pretrained(retriever_path)  # type: ignore[attr-defined]

        # Save the reranker, unwrapping DDP if necessary
        reranker_path = os.path.join(output_dir, "reranker")
        os.makedirs(reranker_path, exist_ok=True)
        if hasattr(self.reranker, 'module'):
            self.reranker.module.save_pretrained(reranker_path)  # type: ignore[attr-defined]
        else:
            self.reranker.save_pretrained(reranker_path)  # type: ignore[attr-defined]

        logger.info(f"Saved models to {output_dir}")
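
# Usage sketch (hedged): constructor arguments for BaseConfig, BGEM3Model, and
# BGERerankerModel live in their own modules, and the outer training loop comes
# from BaseTrainer -- the attribute names below (num_epochs, output_dir) are
# illustrative assumptions only.
#
#     config = BaseConfig(...)            # learning_rate, weight_decay, etc.
#     retriever = BGEM3Model(...)
#     reranker = BGERerankerModel(...)
#     trainer = JointTrainer(config, retriever, reranker)
#     for epoch in range(config.num_epochs):
#         metrics = trainer._train_epoch(train_dataloader, epoch)
#     trainer.save_models(config.output_dir)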