""" Loss functions for BGE fine-tuning """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, List, Optional, Tuple import numpy as np class InfoNCELoss(nn.Module): """ InfoNCE loss for contrastive learning As used in BGE-M3 paper with temperature τ=0.02 Notes: - All parameters (temperature, reduction, etc.) should be passed from config-driven training scripts. - This class does not load config directly; pass config values from your main script for consistency. """ def __init__(self, temperature: float = 0.02, reduction: str = 'mean'): super().__init__() self.temperature = temperature self.reduction = reduction self.cross_entropy = nn.CrossEntropyLoss(reduction=reduction) def forward( self, query_embeddings: torch.Tensor, positive_embeddings: torch.Tensor, negative_embeddings: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Compute InfoNCE loss with proper normalization and numerical stability. Args: query_embeddings: [batch_size, embedding_dim] positive_embeddings: [batch_size, embedding_dim] negative_embeddings: [batch_size, num_negatives, embedding_dim] or None """ batch_size = query_embeddings.size(0) device = query_embeddings.device # Normalize embeddings (L2 normalization) query_embeddings = F.normalize(query_embeddings, p=2, dim=1) positive_embeddings = F.normalize(positive_embeddings, p=2, dim=1) # Compute positive scores (query-positive similarity) # Shape: [batch_size] pos_scores = torch.sum(query_embeddings * positive_embeddings, dim=1) pos_logits = pos_scores / self.temperature # Compute negative scores if provided if negative_embeddings is not None and negative_embeddings.numel() > 0: # Normalize negative embeddings negative_embeddings = F.normalize(negative_embeddings, p=2, dim=-1) # Compute negative scores # Shape: [batch_size, num_negatives] neg_scores = torch.matmul( query_embeddings.unsqueeze(1), negative_embeddings.transpose(1, 2) ).squeeze(1) neg_logits = neg_scores / self.temperature # Concatenate positive and negative logits # Shape: [batch_size, 1 + num_negatives] logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1) # Labels: positive is always at index 0 labels = torch.zeros(batch_size, device=device, dtype=torch.long) else: # Use in-batch negatives as fallback # Compute similarity matrix for in-batch negatives similarity_matrix = torch.matmul(query_embeddings, positive_embeddings.t()) logits = similarity_matrix / self.temperature # Labels: positive pairs are on the diagonal labels = torch.arange(batch_size, device=device, dtype=torch.long) # Compute cross-entropy loss loss = self.cross_entropy(logits, labels) return loss class M3KnowledgeDistillationLoss(nn.Module): """ Knowledge distillation loss for BGE-M3's multi-functionality training Distills knowledge from dense, sparse, and multi-vector representations Notes: - All parameters (temperature, alpha, distill_types, etc.) should be passed from config-driven training scripts. - This class does not load config directly; pass config values from your main script for consistency. 
""" def __init__( self, temperature: float = 4.0, alpha: float = 0.5, distill_types: List[str] = ['dense', 'sparse', 'colbert'] ): super().__init__() self.temperature = temperature self.alpha = alpha self.distill_types = distill_types self.kl_div = nn.KLDivLoss(reduction='batchmean') def forward( self, student_outputs: Dict[str, torch.Tensor], teacher_outputs: Dict[str, torch.Tensor], labels: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, Dict[str, float]]: """ Args: student_outputs: Dictionary with 'dense', 'sparse', 'colbert' outputs teacher_outputs: Dictionary with teacher model outputs labels: Optional labels for supervised loss Returns: total_loss: Combined distillation loss loss_dict: Dictionary of individual losses """ total_loss = torch.tensor(0.0, device=next(iter(student_outputs.values())).device if student_outputs else 'cpu') loss_dict = {} # Distillation loss for each representation type for repr_type in self.distill_types: if repr_type in student_outputs and repr_type in teacher_outputs: student_logits = student_outputs[repr_type] teacher_logits = teacher_outputs[repr_type] # Apply temperature scaling student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1) teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1) # KL divergence loss kl_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2) total_loss = total_loss + self.alpha * kl_loss loss_dict[f'{repr_type}_distill_loss'] = kl_loss.detach().cpu().item() # Add supervised loss if labels provided if labels is not None and 'dense' in student_outputs: ce_loss = F.cross_entropy(student_outputs['dense'], labels) total_loss = total_loss + (1 - self.alpha) * ce_loss loss_dict['supervised_loss'] = ce_loss.detach().cpu().item() loss_dict['total_loss'] = total_loss.detach().cpu().item() return total_loss, loss_dict class ListwiseCrossEntropyLoss(nn.Module): """ Listwise cross-entropy loss for reranker training Notes: - All parameters (temperature, etc.) should be passed from config-driven training scripts. - This class does not load config directly; pass config values from your main script for consistency. """ def __init__(self, temperature: float = 1.0): super().__init__() self.temperature = temperature def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor: """ Compute listwise cross-entropy loss with proper normalization. 
class ListwiseCrossEntropyLoss(nn.Module):
    """
    Listwise cross-entropy loss for reranker training.

    Notes:
        - All parameters (temperature, etc.) should be passed from
          config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
        """
        Compute listwise cross-entropy loss with proper normalization.

        Args:
            scores: Relevance scores [total_passages]
            labels: Binary relevance labels [total_passages]
            groups: List of group sizes (passages per query)
        """
        total_loss = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
        num_valid_groups = 0
        offset = 0

        for group_size in groups:
            if group_size == 0:
                continue

            # Get scores and labels for this group
            group_scores = scores[offset:offset + group_size]
            group_labels = labels[offset:offset + group_size]
            offset += group_size

            # Skip groups with no positive labels
            if group_labels.sum() == 0:
                continue

            # Apply temperature scaling
            scaled_scores = group_scores / self.temperature

            # Apply log_softmax for numerical stability
            log_probs = F.log_softmax(scaled_scores, dim=0)

            # Normalize labels to create a probability distribution
            label_probs = group_labels.float() / group_labels.sum()

            # Compute cross-entropy: -sum(p_true * log(p_pred))
            loss = -(label_probs * log_probs).sum()
            total_loss = total_loss + loss
            num_valid_groups += 1

        # Average over the groups that actually contributed a loss term
        if num_valid_groups > 0:
            return total_loss / num_valid_groups
        return total_loss


class PairwiseMarginLoss(nn.Module):
    """
    Pairwise margin ranking loss for reranker training.

    Notes:
        - All parameters (margin, etc.) should be passed from config-driven
          training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
        """
        Args:
            scores: Relevance scores [total_passages]
            labels: Binary relevance labels [total_passages]
            groups: List of group sizes
        """
        total_loss = torch.tensor(0.0, device=scores.device)
        num_pairs = 0
        offset = 0

        for group_size in groups:
            group_scores = scores[offset:offset + group_size]
            group_labels = labels[offset:offset + group_size]

            # Find positive and negative indices within the group
            pos_indices = torch.where(group_labels == 1)[0]
            neg_indices = torch.where(group_labels == 0)[0]

            if len(pos_indices) > 0 and len(neg_indices) > 0:
                # Create all positive-negative pairs
                for pos_idx in pos_indices:
                    for neg_idx in neg_indices:
                        pos_score = group_scores[pos_idx].unsqueeze(0)
                        neg_score = group_scores[neg_idx].unsqueeze(0)
                        target = torch.ones(1, device=scores.device)

                        loss = self.ranking_loss(pos_score, neg_score, target)
                        total_loss = total_loss + loss
                        num_pairs += 1

            offset += group_size

        return total_loss / max(num_pairs, 1)
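
# Illustrative usage sketch for the reranker losses: scores and labels are
# flattened over all (query, passage) rows, and `groups` records how many
# passages belong to each query. The numbers below are placeholders.
def _example_reranker_losses() -> Tuple[torch.Tensor, torch.Tensor]:
    scores = torch.randn(7)                       # 2 queries with 3 + 4 passages
    labels = torch.tensor([1, 0, 0, 1, 0, 0, 0])  # one positive per query
    groups = [3, 4]
    listwise_loss = ListwiseCrossEntropyLoss(temperature=1.0)(scores, labels, groups)
    pairwise_loss = PairwiseMarginLoss(margin=1.0)(scores, labels, groups)
    return listwise_loss, pairwise_loss
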
""" def __init__(self, temperature: float = 3.0, dynamic_weight: bool = True): super().__init__() self.temperature = temperature self.dynamic_weight = dynamic_weight self.kl_div = nn.KLDivLoss(reduction='none') def forward( self, student_scores: torch.Tensor, teacher_scores: torch.Tensor, groups: List[int], confidence_scores: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Args: student_scores: Student model scores [total_passages] teacher_scores: Teacher model scores [total_passages] groups: List of group sizes confidence_scores: Optional confidence for dynamic weighting """ total_loss = torch.tensor(0.0, device=student_scores.device) offset = 0 for i, group_size in enumerate(groups): student_group = student_scores[offset:offset + group_size] / self.temperature teacher_group = teacher_scores[offset:offset + group_size] / self.temperature # Convert to distributions student_probs = F.log_softmax(student_group, dim=0) teacher_probs = F.softmax(teacher_group, dim=0) # Compute KL divergence kl_loss = self.kl_div(student_probs, teacher_probs).sum() # Apply dynamic weighting based on confidence if self.dynamic_weight and confidence_scores is not None: weight = confidence_scores[i] kl_loss = kl_loss * weight total_loss = total_loss + kl_loss * (self.temperature ** 2) offset += group_size return total_loss / len(groups) class MultiTaskLoss(nn.Module): """ Multi-task loss for combining multiple objectives Notes: - All parameters (loss_weights, adaptive_weights, etc.) should be passed from config-driven training scripts. - This class does not load config directly; pass config values from your main script for consistency. """ def __init__( self, loss_weights: Optional[Dict[str, float]] = None, adaptive_weights: bool = False ): super().__init__() self.loss_weights = loss_weights or { 'dense': 1.0, 'sparse': 0.3, 'colbert': 0.5, 'distillation': 1.0 } self.adaptive_weights = adaptive_weights if adaptive_weights: # Learnable loss weights (uncertainty weighting) self.log_vars = nn.Parameter(torch.zeros(len(self.loss_weights))) def forward(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]: """ Args: losses: Dictionary of individual losses Returns: total_loss: Combined loss loss_dict: Dictionary of individual loss values """ total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device if losses else 'cpu') loss_dict = {} if self.adaptive_weights: # Uncertainty-based weighting for i, (loss_name, loss_value) in enumerate(losses.items()): if loss_name in self.loss_weights: precision = torch.exp(-self.log_vars[i]) weighted_loss = precision * loss_value + self.log_vars[i] total_loss = total_loss + weighted_loss loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value) loss_dict[f'{loss_name}_weight'] = precision.detach().cpu().item() if isinstance(precision, torch.Tensor) else float(precision) else: # Fixed weighting for loss_name, loss_value in losses.items(): if loss_name in self.loss_weights: weight = self.loss_weights[loss_name] weighted_loss = weight * loss_value total_loss = total_loss + weighted_loss loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value) loss_dict[f'{loss_name}_weight'] = weight loss_dict['total_loss'] = total_loss.detach().cpu().item() if isinstance(total_loss, torch.Tensor) else float(total_loss) return total_loss, loss_dict class SparseLoss(nn.Module): """ SPLADE+ sparse loss with log1p(ReLU(.)) 
class SparseLoss(nn.Module):
    """
    SPLADE+ sparse loss with log1p(ReLU(.)) activation and FLOPS regularization.

    For BGE-M3's token-position based sparse representations.

    Args:
        flops_loss_weight: Weight for FLOPS regularization
        sparse_reg_weight: Weight for sparsity regularization
        temperature: Temperature for similarity computation
    """

    def __init__(
        self,
        flops_loss_weight: float = 1e-3,
        sparse_reg_weight: float = 1e-4,
        temperature: float = 0.02
    ):
        super().__init__()
        self.flops_loss_weight = flops_loss_weight
        self.sparse_reg_weight = sparse_reg_weight
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(
        self,
        query_sparse: torch.Tensor,
        doc_sparse: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            query_sparse: [batch, query_seq_len] - Query sparse embeddings (token-position based)
            doc_sparse: [batch, doc_seq_len] - Document sparse embeddings (token-position based)
            labels: [batch] - Labels for contrastive learning (currently unused;
                in-batch positives are taken from the diagonal)

        Returns:
            Tuple of (loss, loss_dict)
        """
        batch_size = query_sparse.size(0)
        device = query_sparse.device

        # SPLADE+ activation: log1p(ReLU(.))
        query_splade = torch.log1p(F.relu(query_sparse))
        doc_splade = torch.log1p(F.relu(doc_sparse))

        # For BGE-M3's token-position sparse embeddings, similarity is computed
        # on aggregated scores: query and document sequences have different
        # lengths, so sum pooling reduces each side to a single sparse mass.
        query_scores = query_splade.sum(dim=1, keepdim=True)  # [batch, 1]
        doc_scores = doc_splade.sum(dim=1, keepdim=True)      # [batch, 1]

        # Normalize the aggregated scores and build a [batch, batch]
        # similarity matrix for in-batch contrastive learning.
        query_norm = F.normalize(query_scores, p=2, dim=1)
        doc_norm = F.normalize(doc_scores, p=2, dim=1)
        similarity_matrix = torch.matmul(query_norm, doc_norm.t())

        # Scale by temperature for the contrastive loss
        logits = similarity_matrix / self.temperature

        # Labels for in-batch contrastive learning (diagonal elements are positive pairs)
        contrastive_labels = torch.arange(batch_size, device=device, dtype=torch.long)

        # Contrastive loss
        main_loss = self.cross_entropy(logits, contrastive_labels)

        # FLOPS regularization (encourage sparsity)
        flops = (query_splade.sum(dim=-1) + doc_splade.sum(dim=-1)).mean()

        # L1 sparsity regularization
        sparsity_loss = (query_splade.sum() + doc_splade.sum()) / (query_splade.numel() + doc_splade.numel())

        # Combine losses
        total_loss = main_loss + self.flops_loss_weight * flops + self.sparse_reg_weight * sparsity_loss

        loss_dict = {
            'main_loss': main_loss.detach().cpu().item(),
            'flops': flops.detach().cpu().item(),
            'sparsity_loss': sparsity_loss.detach().cpu().item(),
            'total_loss': total_loss.detach().cpu().item()
        }
        return total_loss, loss_dict
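
# Illustrative usage sketch for SparseLoss: each input carries one weight per
# token position (negative values are clipped by the internal ReLU), and query
# and document sequence lengths may differ. Shapes are placeholders; `labels`
# is accepted for interface compatibility but the positives are the diagonal.
def _example_sparse_loss() -> torch.Tensor:
    sparse_loss = SparseLoss(flops_loss_weight=1e-3, sparse_reg_weight=1e-4)
    query_weights = torch.randn(4, 32)   # [batch, query_seq_len]
    doc_weights = torch.randn(4, 128)    # [batch, doc_seq_len]
    labels = torch.arange(4)             # diagonal pairs are the positives
    loss, loss_dict = sparse_loss(query_weights, doc_weights, labels)
    return loss
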
class ColBERTLoss(nn.Module):
    """
    ColBERT late interaction (maxsim) loss.

    Args:
        temperature: Softmax temperature
    """

    def __init__(self, temperature: float = 0.02):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query_embeds: torch.Tensor,
        doc_embeds: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            query_embeds: [batch, query_len, dim]
            doc_embeds: [batch, doc_len, dim]
            labels: [batch] binary relevance labels

        Returns:
            Loss value
        """
        # Late interaction: for each query token, take the maximum similarity
        # over document tokens (maxsim).
        # sim_matrix: [batch, query_len, doc_len]
        sim_matrix = torch.einsum('bqd,bkd->bqk', query_embeds, doc_embeds) / self.temperature
        maxsim = sim_matrix.max(dim=2)[0]  # [batch, query_len]

        # Aggregate over query tokens (mean)
        scores = maxsim.mean(dim=1)  # [batch]

        # Binary cross-entropy loss against the relevance labels
        loss = F.binary_cross_entropy_with_logits(scores, labels.float())
        return loss
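
# Illustrative usage sketch for ColBERTLoss: token-level embeddings for each
# query/document pair plus a binary relevance label per pair. L2-normalizing
# the token embeddings (done here for illustration) keeps the maxsim scores
# bounded; batch size, sequence lengths, and dimension are placeholders.
def _example_colbert_loss() -> torch.Tensor:
    colbert_loss = ColBERTLoss(temperature=0.02)
    query_tokens = F.normalize(torch.randn(4, 16, 128), dim=-1)  # [batch, query_len, dim]
    doc_tokens = F.normalize(torch.randn(4, 64, 128), dim=-1)    # [batch, doc_len, dim]
    labels = torch.tensor([1, 0, 1, 0])                          # relevance per pair
    return colbert_loss(query_tokens, doc_tokens, labels)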