import os
import json
import logging
from typing import Dict, Optional, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.losses import InfoNCELoss, M3KnowledgeDistillationLoss, MultiTaskLoss
from data.dataset import BGEM3Dataset

# Optional hard negative mining
try:
    from data.hard_negative_mining import ANCEHardNegativeMiner
    ANCE_AVAILABLE = True
except ImportError:
    ANCE_AVAILABLE = False
    ANCEHardNegativeMiner = None

from utils.checkpoint import CheckpointManager
from utils.logging import get_logger
from evaluation.metrics import compute_retrieval_metrics
from config.m3_config import BGEM3Config

# Import scheduler with fallback
try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None

logger = get_logger(__name__)


class BGEM3Trainer(BaseTrainer):
    """
    Trainer for the BGE-M3 embedding model, combining dense, sparse, and
    multi-vector (ColBERT) objectives with optional self-distillation and
    ANCE-style hard negative mining.

    Notes:
        - All parameters (config, model, tokenizer, etc.) should be passed in
          from config-driven training scripts.
        - This class does not load config files directly; pass config values
          from your main script for consistency.
        - A usage sketch is provided at the end of this module.
    """

    def __init__(
        self,
        config: BGEM3Config,
        model: Optional[BGEM3Model] = None,
        tokenizer: Optional[AutoTokenizer] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        # Store config first
        self.config = config

        # Initialize tokenizer
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.model_name_or_path,
                cache_dir=config.cache_dir,
                trust_remote_code=True
            )
        else:
            self.tokenizer = tokenizer

        # Initialize model
        if model is None:
            self.model = BGEM3Model(
                model_name_or_path=config.model_name_or_path,
                use_dense=config.use_dense,
                use_sparse=config.use_sparse,
                use_colbert=config.use_colbert,
                pooling_method=config.sentence_pooling_method,
                normalize_embeddings=config.normalize_embeddings,
                colbert_dim=config.colbert_dim,
                sparse_top_k=config.sparse_top_k,
                cache_dir=config.cache_dir
            )
        else:
            self.model = model

        # Initialize optimizer if not provided
        if optimizer is None:
            optimizer = self._create_optimizer()

        # If no scheduler is provided, it is created later in create_scheduler()
        # once the total number of training steps is known.

        # Initialize parent class
        super().__init__(config, self.model, optimizer, scheduler, device)

        # Initialize losses
        self.infonce_loss = InfoNCELoss(
            temperature=config.temperature,
            reduction='mean'
        )

        self.kd_loss = M3KnowledgeDistillationLoss(
            temperature=config.self_distill_temperature,
            alpha=config.self_distill_alpha,
            distill_types=['dense', 'sparse', 'colbert']
        )

        self.multi_task_loss = MultiTaskLoss(
            loss_weights={
                'dense': config.dense_weight,
                'sparse': config.sparse_weight,
                'colbert': config.colbert_weight,
                'distillation': config.distillation_weight
            }
        )

        # Initialize hard negative miner
        if config.use_hard_negatives and ANCE_AVAILABLE:
            self.hard_negative_miner = ANCEHardNegativeMiner(
                embedding_dim=self.model.config.hidden_size,  # type: ignore[attr-defined]
                num_hard_negatives=config.num_hard_negatives,
                negative_range=config.ance_negative_range,
                margin=config.ance_margin,
                index_refresh_interval=config.hard_negative_mining_steps,
                use_gpu=str(self.device).startswith("cuda")  # covers "cuda" and "cuda:N"
            )
            self.hard_negative_miner.start_async_updates()
        else:
            if config.use_hard_negatives and not ANCE_AVAILABLE:
                logger.warning("Hard negative mining requested but ANCEHardNegativeMiner is not available")
            self.hard_negative_miner = None

        logger.info(f"Initialized BGE-M3 trainer with config: {config.to_dict()}")

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create optimizer with parameter groups"""
        # Separate parameters for different learning rates if needed
        param_groups = [
            {
                'params': [
                    p for n, p in self.model.named_parameters()
                    if 'colbert_linear' not in n and 'sparse_linear' not in n
                ],
                'lr': self.config.learning_rate
            }
        ]

        # Different LR for task-specific layers
        if self.config.use_colbert and hasattr(self.model, 'colbert_linear'):
            param_groups.append({
                'params': self.model.colbert_linear.parameters(),
                'lr': self.config.learning_rate * 2  # Higher LR for newly initialized layers
            })

        if self.config.use_sparse and hasattr(self.model, 'sparse_linear'):
            param_groups.append({
                'params': self.model.sparse_linear.parameters(),
                'lr': self.config.learning_rate * 2
            })

        optimizer = AdamW(
            param_groups,
            eps=self.config.adam_epsilon,
            weight_decay=self.config.weight_decay,
            betas=(self.config.adam_beta1, self.config.adam_beta2)
        )

        return optimizer

    def create_scheduler(self, num_training_steps: int):
        """Create learning rate scheduler"""
        if get_linear_schedule_with_warmup is None:
            logger.warning("transformers scheduler not available, using constant learning rate")
            return None

        num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )

        return self.scheduler

    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute M3 loss with multiple objectives"""
        # Update hard negative mining if the hook is available
        if hasattr(self, '_mining_hook'):
            self._mining_hook()

        # Detect batch format and handle accordingly
        if 'pos_input_ids' in batch and 'neg_input_ids' in batch:
            # Triplet format from FastM3Dataset
            return self._compute_triplet_loss(batch)
        elif 'passage_input_ids' in batch:
            # Standard embedding format
            return self._compute_embedding_loss(batch)
        else:
            raise ValueError(f"Unknown batch format. Keys: {list(batch.keys())}")

    def _compute_triplet_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for triplet format data (FastM3Dataset)"""
        # Get query embeddings
        query_outputs = self.model.forward(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Get positive embeddings
        pos_outputs = self.model.forward(
            input_ids=batch['pos_input_ids'],
            attention_mask=batch['pos_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Get negative embeddings
        neg_outputs = self.model.forward(
            input_ids=batch['neg_input_ids'],
            attention_mask=batch['neg_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        losses = {}

        # Dense loss (InfoNCE with explicit pos/neg)
        if 'dense' in query_outputs:
            query_dense = query_outputs['dense'].float()  # type: ignore[index]
            pos_dense = pos_outputs['dense'].float()  # type: ignore[index]
            neg_dense = neg_outputs['dense'].float()  # type: ignore[index]

            # InfoNCE expects negative_embeddings shaped [batch_size, num_negatives, embedding_dim],
            # but the triplet format provides a single negative per query shaped
            # [batch_size, embedding_dim], so add a negatives dimension.
            neg_dense_3d = neg_dense.unsqueeze(1)  # [batch_size, 1, embedding_dim]

            dense_loss = self.infonce_loss(
                query_dense,
                pos_dense,
                neg_dense_3d  # Properly shaped 3D tensor
            )
            losses['dense'] = dense_loss

        # Sparse loss
        if self.config.use_sparse and 'sparse' in query_outputs:
            from models.losses import SparseLoss

            sparse_loss_fn = SparseLoss(
                flops_loss_weight=1e-3,
                sparse_reg_weight=1e-4
            )

            # Combine pos and neg for the sparse loss, labelled 1 / 0 respectively
            combined_sparse = torch.cat(
                [pos_outputs['sparse'], neg_outputs['sparse']], dim=0  # type: ignore[index]
            )
            query_sparse_repeated = query_outputs['sparse'].repeat(2, 1)  # type: ignore[index]
            labels = torch.cat(
                [torch.ones(len(pos_outputs['sparse'])),  # type: ignore[index]
                 torch.zeros(len(neg_outputs['sparse']))],  # type: ignore[index]
                dim=0
            ).to(query_sparse_repeated.device)

            sparse_loss, _ = sparse_loss_fn(
                query_sparse_repeated,
                combined_sparse,
                labels
            )
            losses['sparse'] = sparse_loss

        # ColBERT loss
        if self.config.use_colbert and 'colbert' in query_outputs:
            from models.losses import ColBERTLoss

            colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)

            # Combine pos and neg for the ColBERT loss, labelled 1 / 0 respectively
            combined_colbert = torch.cat(
                [pos_outputs['colbert'], neg_outputs['colbert']], dim=0  # type: ignore[index]
            )
            query_colbert_repeated = query_outputs['colbert'].repeat(2, 1, 1)  # type: ignore[index]
            labels = torch.cat(
                [torch.ones(len(pos_outputs['colbert'])),  # type: ignore[index]
                 torch.zeros(len(neg_outputs['colbert']))],  # type: ignore[index]
                dim=0
            ).to(query_colbert_repeated.device)

            colbert_loss = colbert_loss_fn(
                query_colbert_repeated,
                combined_colbert,
                labels
            )
            losses['colbert'] = colbert_loss

        # Combine losses
        total_loss, _ = self.multi_task_loss(losses)

        return total_loss

    def _compute_embedding_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for standard embedding format data"""
        # Get query embeddings
        query_outputs = self.model.forward(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Get passage embeddings - handle both single and multiple passages per query
        if batch['passage_input_ids'].dim() == 3:
            # Multiple passages per query: [batch, num_passages, seq_len]
            batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
            passage_input_ids = batch['passage_input_ids'].view(batch_size * num_passages, seq_len)
            passage_attention_mask = batch['passage_attention_mask'].view(batch_size * num_passages, seq_len)
            multiple_passages = True
        else:
            # Single passage per query: [batch, seq_len]
            passage_input_ids = batch['passage_input_ids']
            passage_attention_mask = batch['passage_attention_mask']
            batch_size = passage_input_ids.shape[0]
            num_passages = 1
            multiple_passages = False

        passage_outputs = self.model.forward(
            input_ids=passage_input_ids,
            attention_mask=passage_attention_mask,
            return_dense=True,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Compute losses for the different components
        losses = {}

        # Dense loss (InfoNCE)
        if 'dense' in query_outputs and 'dense' in passage_outputs:
            # Cast to float32 for the loss computation
            query_dense = query_outputs['dense'].float()  # type: ignore[index]  # [batch_size, embedding_dim]
            passage_dense = passage_outputs['dense'].float()  # type: ignore[index]  # [batch_size * num_passages, embedding_dim]

            # Handle multiple passages per query differently for contrastive learning
            if multiple_passages:
                # Reshape passage embeddings back to [batch_size, num_passages, embedding_dim]
                passage_dense = passage_dense.view(batch_size, num_passages, -1)

                # For contrastive learning we need to identify positive passages;
                # use the labels to find the first positive passage for each query.
                labels = batch.get('labels')  # [batch_size, num_passages]

                if labels is not None:
                    positive_indices = []
                    negative_passages = []

                    for i in range(batch_size):
                        query_labels = labels[i]  # [num_passages]
                        positive_mask = query_labels > 0

                        if positive_mask.any():
                            # Take the first positive passage
                            pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
                            positive_indices.append(pos_idx)

                            # Collect negative passages
                            neg_mask = ~positive_mask
                            if neg_mask.any():
                                neg_embeddings = passage_dense[i][neg_mask]  # [num_negatives, embedding_dim]
                                negative_passages.append(neg_embeddings)
                            else:
                                negative_passages.append(None)
                        else:
                            # No positive passage; use the first passage as positive (fallback)
                            positive_indices.append(0)
                            negative_passages.append(passage_dense[i][1:])

                    # Extract positive embeddings
                    positive_embeddings = torch.stack([
                        passage_dense[i, positive_indices[i]] for i in range(batch_size)
                    ])  # [batch_size, embedding_dim]

                    # Stack negative embeddings (if available)
                    max_neg = max(neg.shape[0] if neg is not None else 0 for neg in negative_passages)

                    if max_neg > 0:
                        # Pad negative embeddings to the same size
                        padded_negatives = []
                        for neg in negative_passages:
                            if neg is not None and neg.shape[0] > 0:
                                if neg.shape[0] < max_neg:
                                    # Pad with zeros
                                    padding = torch.zeros(
                                        max_neg - neg.shape[0], neg.shape[1], device=neg.device
                                    )
                                    neg = torch.cat([neg, padding], dim=0)
                                padded_negatives.append(neg)
                            else:
                                # No negatives, create a zero tensor placeholder
                                padded_negatives.append(
                                    torch.zeros(max_neg, passage_dense.shape[-1], device=passage_dense.device)
                                )

                        negative_embeddings = torch.stack(padded_negatives)  # [batch_size, max_neg, embedding_dim]
                    else:
                        negative_embeddings = None

                    # Compute InfoNCE loss with explicit positives and negatives
                    dense_loss = self.infonce_loss(
                        query_dense,
                        positive_embeddings,
                        negative_embeddings
                    )
                else:
                    # No labels provided: use the first passage as positive, the rest as negatives
                    positive_embeddings = passage_dense[:, 0, :]  # [batch_size, embedding_dim]
                    if num_passages > 1:
                        negative_embeddings = passage_dense[:, 1:, :]  # [batch_size, num_passages - 1, embedding_dim]
                    else:
                        negative_embeddings = None

                    dense_loss = self.infonce_loss(
                        query_dense,
                        positive_embeddings,
                        negative_embeddings
                    )
            else:
                # Single passage per query - use standard in-batch negatives;
                # InfoNCE creates contrastive pairs from the rest of the batch.
                dense_loss = self.infonce_loss(
                    query_dense,
                    passage_dense,
                    None  # No explicit negatives, will use in-batch
                )

            losses['dense'] = dense_loss

        # Sparse loss (if using sparse)
        if self.config.use_sparse and 'sparse' in query_outputs and 'sparse' in passage_outputs:
            from models.losses import SparseLoss

            sparse_loss_fn = SparseLoss(
                flops_loss_weight=1e-3,
                sparse_reg_weight=1e-4
            )

            query_sparse = query_outputs['sparse']  # type: ignore[index]  # [batch_size, query_seq_len]
            passage_sparse = passage_outputs['sparse']  # type: ignore[index]  # [batch_size * num_passages, doc_seq_len]

            # Sparse embeddings have different sequence lengths, so follow the
            # same positive-selection strategy as the dense loss.
            if multiple_passages:
                # Use only the first positive passage for each query
                labels = batch.get('labels')
                if labels is not None:
                    positive_indices = []
                    for i in range(batch_size):
                        query_labels = labels[i]
                        positive_mask = query_labels > 0
                        if positive_mask.any():
                            pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
                        else:
                            pos_idx = 0
                        positive_indices.append(pos_idx + i * num_passages)

                    # Extract positive sparse embeddings
                    positive_sparse = passage_sparse[positive_indices]  # [batch_size, doc_seq_len]
                else:
                    # Use the first passage as positive
                    positive_sparse = passage_sparse[::num_passages]  # [batch_size, doc_seq_len]

                # Create contrastive labels for the sparse loss
                contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)

                sparse_loss, _ = sparse_loss_fn(
                    query_sparse,
                    positive_sparse,
                    contrastive_labels
                )
            else:
                # Single passage per query
                contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)

                sparse_loss, _ = sparse_loss_fn(
                    query_sparse,
                    passage_sparse,
                    contrastive_labels
                )

            losses['sparse'] = sparse_loss

        # ColBERT loss (if using ColBERT)
        if self.config.use_colbert and 'colbert' in query_outputs and 'colbert' in passage_outputs:
            from models.losses import ColBERTLoss

            colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)

            query_colbert = query_outputs['colbert']  # type: ignore[index]  # [batch_size, query_seq_len, dim]
            passage_colbert = passage_outputs['colbert']  # type: ignore[index]  # [batch_size * num_passages, doc_seq_len, dim]

            # For ColBERT, we need pairwise query/passage interactions
            if multiple_passages:
                # Use only positive passages
                labels = batch.get('labels')
                if labels is not None:
                    positive_indices = []
                    for i in range(batch_size):
                        query_labels = labels[i]
                        positive_mask = query_labels > 0
                        if positive_mask.any():
                            pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
                        else:
                            pos_idx = 0
                        positive_indices.append(pos_idx + i * num_passages)

                    positive_colbert = passage_colbert[positive_indices]
                else:
                    positive_colbert = passage_colbert[::num_passages]

                # Binary labels for ColBERT (1 for positive pairs)
                colbert_labels = torch.ones(batch_size, device=query_colbert.device)
            else:
                positive_colbert = passage_colbert
                colbert_labels = torch.ones(batch_size, device=query_colbert.device)

            colbert_loss = colbert_loss_fn(
                query_colbert,
                positive_colbert,
                colbert_labels
            )
            losses['colbert'] = colbert_loss

        # Knowledge distillation loss (only if explicitly enabled)
        if (getattr(self.config, 'use_self_distill', False)
                and self.global_step >= getattr(self.config, 'self_distill_start_step', 0)):
            try:
                # Compute teacher scores (weighted combination of signals)
                teacher_scores = self._compute_teacher_scores(query_outputs, passage_outputs)  # type: ignore[arg-type]

                student_scores = self.model.compute_similarity(
                    query_outputs,
                    passage_outputs,
                    mode='dense'
                )

                kd_loss, _ = self.kd_loss(
                    {'dense': student_scores},
                    {'dense': teacher_scores}
                )
                losses['kd'] = kd_loss
            except Exception as e:
                logger.warning(f"Knowledge distillation failed: {e}")

        # Combine losses
        total_loss, _ = self.multi_task_loss(losses)

        return total_loss

    def _compute_teacher_scores(
        self,
        query_outputs: Dict[str, torch.Tensor],
        passage_outputs: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Compute teacher scores from multiple signals"""
        scores = []
        weights = []

        # Get original batch sizes to detect the multi-passage case
        query_batch_size = query_outputs['dense'].shape[0] if 'dense' in query_outputs else 1
        passage_batch_size = passage_outputs['dense'].shape[0] if 'dense' in passage_outputs else 1

        # Check if we have multiple passages per query
        multiple_passages = passage_batch_size > query_batch_size

        if multiple_passages:
            # With multiple passages, compute scores only for positive pairs;
            # as a simplification, use the first passage of each query as the teacher positive.
            num_passages = passage_batch_size // query_batch_size

            # Dense scores (use only the first passage per query)
            if 'dense' in query_outputs and 'dense' in passage_outputs:
                query_dense = query_outputs['dense']  # [batch_size, dim]
                passage_dense = passage_outputs['dense']  # [batch_size * num_passages, dim]

                # Extract the first passage for each query
                positive_passages = passage_dense[::num_passages]  # [batch_size, dim]

                try:
                    dense_scores = self.model.compute_similarity(
                        {'dense': query_dense},
                        {'dense': positive_passages},
                        mode='dense'
                    )
                    scores.append(dense_scores)
                    weights.append(self.config.dense_weight)
                except Exception as e:
                    logger.warning(f"Dense similarity computation failed: {e}")

            # Sparse scores (use only the first passage per query)
            if 'sparse' in query_outputs and 'sparse' in passage_outputs:
                query_sparse = query_outputs['sparse']  # [batch_size, query_seq_len]
                passage_sparse = passage_outputs['sparse']  # [batch_size * num_passages, doc_seq_len]

                # Extract the first passage for each query
                positive_passages = passage_sparse[::num_passages]  # [batch_size, doc_seq_len]

                try:
                    sparse_scores = self.model.compute_similarity(
                        {'sparse': query_sparse},
                        {'sparse': positive_passages},
                        mode='sparse'
                    )
                    scores.append(sparse_scores)
                    weights.append(self.config.sparse_weight)
                except Exception as e:
                    logger.warning(f"Sparse similarity computation failed: {e}")

            # ColBERT scores (use only the first passage per query)
            if 'colbert' in query_outputs and 'colbert' in passage_outputs:
                query_colbert = query_outputs['colbert']  # [batch_size, query_seq_len, dim]
                passage_colbert = passage_outputs['colbert']  # [batch_size * num_passages, doc_seq_len, dim]

                # Extract the first passage for each query
                positive_passages = passage_colbert[::num_passages]  # [batch_size, doc_seq_len, dim]

                try:
                    colbert_scores = self.model.compute_similarity(
                        {'colbert': query_colbert},
                        {'colbert': positive_passages},
                        mode='colbert'
                    )
                    scores.append(colbert_scores)
                    weights.append(self.config.colbert_weight)
                except Exception as e:
                    logger.warning(f"ColBERT similarity computation failed: {e}")
        else:
            # Single passage per query - original logic
            # Dense scores
            if 'dense' in query_outputs:
                try:
                    dense_scores = self.model.compute_similarity(
                        {'dense': query_outputs['dense']},
                        {'dense': passage_outputs['dense']},
                        mode='dense'
                    )
                    scores.append(dense_scores)
                    weights.append(self.config.dense_weight)
                except Exception as e:
                    logger.warning(f"Dense similarity computation failed: {e}")

            # Sparse scores
            if 'sparse' in query_outputs:
                try:
                    sparse_scores = self.model.compute_similarity(
                        {'sparse': query_outputs['sparse']},
                        {'sparse': passage_outputs['sparse']},
                        mode='sparse'
                    )
                    scores.append(sparse_scores)
                    weights.append(self.config.sparse_weight)
                except Exception as e:
                    logger.warning(f"Sparse similarity computation failed: {e}")

            # ColBERT scores
            if 'colbert' in query_outputs:
                try:
                    colbert_scores = self.model.compute_similarity(
                        {'colbert': query_outputs['colbert']},
                        {'colbert': passage_outputs['colbert']},
                        mode='colbert'
                    )
                    scores.append(colbert_scores)
                    weights.append(self.config.colbert_weight)
                except Exception as e:
                    logger.warning(f"ColBERT similarity computation failed: {e}")

        # Weighted combination of the available signals
        if scores:
            weights = torch.tensor(weights, device=scores[0].device)
            weights = weights / weights.sum()
            teacher_scores = torch.stack([w * s for w, s in zip(weights, scores)]).sum(dim=0)
            return teacher_scores

        # Fallback to zero scores if no valid similarity could be computed
        fallback_device = query_outputs[list(query_outputs.keys())[0]].device
        return torch.zeros(query_batch_size, device=fallback_device)

    def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Evaluate a single batch"""
        # Get query embeddings
        query_outputs = self.model.forward(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True,
            return_dict=True
        )

        # Get passage embeddings
        passage_outputs = self.model.forward(
            input_ids=batch['passage_input_ids'],
            attention_mask=batch['passage_attention_mask'],
            return_dense=True,
            return_dict=True
        )

        # Compute loss
        loss = self.compute_loss(batch)

        # Compute similarity scores
        scores = self.model.compute_similarity(
            query_outputs,
            passage_outputs,
            mode='dense'
        )

        # Basic metrics
        labels = batch.get('labels', torch.zeros(scores.shape[0], device=scores.device))
        if labels.dim() > 1:
            labels = labels.view(-1)
        labels = labels.to(scores.device)

        predictions = (scores > 0.5).float()
        accuracy = (predictions == labels.float()).float().mean()

        return {
            'loss': loss.item(),
            'accuracy': accuracy.item(),
            'mean_score': scores.mean().item()
        }

    def train(
        self,
        train_dataset: BGEM3Dataset,
        eval_dataset: Optional[BGEM3Dataset] = None,
        resume_from_checkpoint: Optional[str] = None
    ):
        """
        Main training loop for the BGE-M3 trainer, with robust error handling,
        config-driven parameters, and progress tracking.
        """
        try:
            train_dataloader = self._create_dataloader(train_dataset, is_train=True)
            eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None

            # Calculate total training steps
            num_update_steps_per_epoch = max(
                1, len(train_dataloader) // self.config.gradient_accumulation_steps  # type: ignore[arg-type]
            )
            total_training_steps = num_update_steps_per_epoch * self.config.num_train_epochs

            # Create the scheduler now that the total number of steps is known
            self.create_scheduler(total_training_steps)

            if resume_from_checkpoint:
                self.load_checkpoint(resume_from_checkpoint)

            logger.info(f"Starting training for {self.config.num_train_epochs} epochs")
            logger.info(f"Total training steps: {total_training_steps}")

            # Use parent's train method
            super().train(train_dataloader, eval_dataloader, self.config.num_train_epochs)  # type: ignore[arg-type]

        except Exception as e:
            logger.error(f"M3 training failed: {e}")
            raise
        finally:
            # Cleanup
            if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
                self.hard_negative_miner.stop_async_updates()

    def _create_dataloader(
        self,
        dataset: Optional[BGEM3Dataset],
        is_train: bool = True
    ) -> Optional[DataLoader]:
        """Create data loader"""
        if dataset is None:
            return None

        # Collate function for embedding-format batches
        from data.dataset import collate_embedding_batch
        collate_fn = collate_embedding_batch

        num_workers = self.config.dataloader_num_workers
        batch_size = (
            self.config.per_device_train_batch_size if is_train
            else self.config.per_device_eval_batch_size
        )

        # Setup distributed sampler if needed
        sampler = None
        if self.is_distributed:
            from torch.utils.data.distributed import DistributedSampler
            sampler = DistributedSampler(dataset, shuffle=is_train)

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train and sampler is None,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=self.config.dataloader_pin_memory,
            drop_last=is_train and len(dataset) > batch_size,  # Only drop last if we have multiple batches
            collate_fn=collate_fn
        )

        return dataloader

    def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train one epoch - use parent's method with hard negative mining integration"""
        # Register hard negative mining updates before calling parent's method
        if self.hard_negative_miner:
            # Set up a hook to refresh hard negatives during training
            def mining_hook():
                if self.hard_negative_miner:
                    self.hard_negative_miner.update_index_if_needed(self.global_step, self.model)

            # Store the hook for use in compute_loss
            self._mining_hook = mining_hook

        # Use parent's robust training loop
        return super().train_epoch(dataloader, epoch)

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate the model"""
        self.model.eval()

        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []

        with torch.no_grad():
            with tqdm(dataloader, desc="Evaluating") as tbar:
                for batch in tbar:
                    try:
                        batch = self._prepare_batch(batch)

                        query_outputs = self.model.forward(
                            input_ids=batch['query_input_ids'],
                            attention_mask=batch['query_attention_mask'],
                            return_dense=True,
                            return_dict=True
                        )

                        passage_outputs = self.model.forward(
                            input_ids=batch['passage_input_ids'],
                            attention_mask=batch['passage_attention_mask'],
                            return_dense=True,
                            return_dict=True
                        )

                        all_query_embeddings.append(query_outputs['dense'].cpu())  # type: ignore[index]
                        all_passage_embeddings.append(passage_outputs['dense'].cpu())  # type: ignore[index]
                        all_labels.append(batch['labels'].cpu())

                    except Exception as e:
                        logger.error(f"Error in M3 evaluation step: {e}")
                        continue

        if not all_query_embeddings:
            logger.warning("No valid evaluation batches processed")
            return {'loss': float('inf'), 'accuracy': 0.0}

        # Concatenate all embeddings
        all_query_embeddings = torch.cat(all_query_embeddings, dim=0)
        all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute metrics
        try:
            metrics = compute_retrieval_metrics(
                all_query_embeddings.numpy(),
                all_passage_embeddings.numpy(),
                all_labels.numpy(),
                k_values=[1, 3, 5, 10]
            )
        except Exception as e:
            logger.error(f"Error computing retrieval metrics: {e}")
            metrics = {'loss': float('inf'), 'accuracy': 0.0}

        return metrics

    def save_model(self, save_path: str):
        """Save the trained model, tokenizer, and training config"""
        try:
            os.makedirs(save_path, exist_ok=True)

            # Save model (unwrap DDP if necessary)
            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
            model_to_save.save_pretrained(save_path)  # type: ignore[attr-defined]

            # Save tokenizer
            self.tokenizer.save_pretrained(save_path)  # type: ignore[attr-defined]

            # Save config
            self.config.save(os.path.join(save_path, 'training_config.json'))

            logger.info(f"Model saved to {save_path}")

        except Exception as e:
            logger.error(f"Error saving model: {e}")
            raise

    def __del__(self):
        """Cleanup when trainer is destroyed"""
        if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
            self.hard_negative_miner.stop_async_updates()