bge_finetune/models/losses.py

"""
Loss functions for BGE fine-tuning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import numpy as np
class InfoNCELoss(nn.Module):
"""
InfoNCE loss for contrastive learning
    As used in the BGE-M3 paper, with a default temperature of τ = 0.02
Notes:
- All parameters (temperature, reduction, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
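    Example (illustrative usage sketch, not from the training pipeline; the batch size,
    embedding dim, and negative count below are assumptions):
        >>> loss_fn = InfoNCELoss(temperature=0.02)
        >>> q = torch.randn(8, 768)        # query embeddings
        >>> p = torch.randn(8, 768)        # positive embeddings
        >>> n = torch.randn(8, 4, 768)     # 4 hard negatives per query
        >>> loss_hard = loss_fn(q, p, n)   # explicit negatives; positive sits at logit index 0
        >>> loss_in_batch = loss_fn(q, p)  # no negatives: falls back to in-batch negatives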
"""
def __init__(self, temperature: float = 0.02, reduction: str = 'mean'):
super().__init__()
self.temperature = temperature
self.reduction = reduction
self.cross_entropy = nn.CrossEntropyLoss(reduction=reduction)
def forward(
self,
query_embeddings: torch.Tensor,
positive_embeddings: torch.Tensor,
negative_embeddings: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute InfoNCE loss with proper normalization and numerical stability.
Args:
query_embeddings: [batch_size, embedding_dim]
positive_embeddings: [batch_size, embedding_dim]
negative_embeddings: [batch_size, num_negatives, embedding_dim] or None
"""
batch_size = query_embeddings.size(0)
device = query_embeddings.device
# Normalize embeddings (L2 normalization)
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
positive_embeddings = F.normalize(positive_embeddings, p=2, dim=1)
# Compute positive scores (query-positive similarity)
# Shape: [batch_size]
pos_scores = torch.sum(query_embeddings * positive_embeddings, dim=1)
pos_logits = pos_scores / self.temperature
# Compute negative scores if provided
if negative_embeddings is not None and negative_embeddings.numel() > 0:
# Normalize negative embeddings
negative_embeddings = F.normalize(negative_embeddings, p=2, dim=-1)
# Compute negative scores
# Shape: [batch_size, num_negatives]
neg_scores = torch.matmul(
query_embeddings.unsqueeze(1),
negative_embeddings.transpose(1, 2)
).squeeze(1)
neg_logits = neg_scores / self.temperature
# Concatenate positive and negative logits
# Shape: [batch_size, 1 + num_negatives]
logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
# Labels: positive is always at index 0
labels = torch.zeros(batch_size, device=device, dtype=torch.long)
else:
# Use in-batch negatives as fallback
# Compute similarity matrix for in-batch negatives
similarity_matrix = torch.matmul(query_embeddings, positive_embeddings.t())
logits = similarity_matrix / self.temperature
# Labels: positive pairs are on the diagonal
labels = torch.arange(batch_size, device=device, dtype=torch.long)
# Compute cross-entropy loss
loss = self.cross_entropy(logits, labels)
return loss
class M3KnowledgeDistillationLoss(nn.Module):
"""
Knowledge distillation loss for BGE-M3's multi-functionality training
Distills knowledge from dense, sparse, and multi-vector representations
Notes:
- All parameters (temperature, alpha, distill_types, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
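    Example (illustrative usage sketch; the logit shapes and values below are assumptions):
        >>> kd_loss = M3KnowledgeDistillationLoss(temperature=4.0, alpha=0.5)
        >>> student = {'dense': torch.randn(8, 16), 'sparse': torch.randn(8, 16)}
        >>> teacher = {'dense': torch.randn(8, 16), 'sparse': torch.randn(8, 16)}
        >>> loss, loss_dict = kd_loss(student, teacher)   # per-type KL terms land in loss_dict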
"""
    def __init__(
        self,
        temperature: float = 4.0,
        alpha: float = 0.5,
        distill_types: Optional[List[str]] = None
    ):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        # Avoid a mutable default argument; fall back to BGE-M3's three representation types
        self.distill_types = distill_types if distill_types is not None else ['dense', 'sparse', 'colbert']
self.kl_div = nn.KLDivLoss(reduction='batchmean')
def forward(
self,
student_outputs: Dict[str, torch.Tensor],
teacher_outputs: Dict[str, torch.Tensor],
labels: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
student_outputs: Dictionary with 'dense', 'sparse', 'colbert' outputs
teacher_outputs: Dictionary with teacher model outputs
labels: Optional labels for supervised loss
Returns:
total_loss: Combined distillation loss
loss_dict: Dictionary of individual losses
"""
total_loss = torch.tensor(0.0, device=next(iter(student_outputs.values())).device if student_outputs else 'cpu')
loss_dict = {}
# Distillation loss for each representation type
for repr_type in self.distill_types:
if repr_type in student_outputs and repr_type in teacher_outputs:
student_logits = student_outputs[repr_type]
teacher_logits = teacher_outputs[repr_type]
# Apply temperature scaling
student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
# KL divergence loss
kl_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)
total_loss = total_loss + self.alpha * kl_loss
loss_dict[f'{repr_type}_distill_loss'] = kl_loss.detach().cpu().item()
# Add supervised loss if labels provided
if labels is not None and 'dense' in student_outputs:
ce_loss = F.cross_entropy(student_outputs['dense'], labels)
total_loss = total_loss + (1 - self.alpha) * ce_loss
loss_dict['supervised_loss'] = ce_loss.detach().cpu().item()
loss_dict['total_loss'] = total_loss.detach().cpu().item()
return total_loss, loss_dict
class ListwiseCrossEntropyLoss(nn.Module):
"""
Listwise cross-entropy loss for reranker training
Notes:
- All parameters (temperature, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
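    Example (illustrative usage sketch; the scores, labels, and group sizes are made-up values):
        >>> loss_fn = ListwiseCrossEntropyLoss(temperature=1.0)
        >>> scores = torch.tensor([2.0, 0.5, 0.1, 1.5, 0.2])   # 5 passages in total
        >>> labels = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0])   # binary relevance
        >>> groups = [3, 2]                                    # passages per query
        >>> loss = loss_fn(scores, labels, groups)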
"""
def __init__(self, temperature: float = 1.0):
super().__init__()
self.temperature = temperature
def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
"""
Compute listwise cross-entropy loss with proper normalization.
Args:
scores: Relevance scores [total_passages]
labels: Binary relevance labels [total_passages]
groups: List of group sizes (passages per query)
"""
        total_loss = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
        num_valid_groups = 0
        offset = 0
        for group_size in groups:
            if group_size == 0:
                continue
            # Get scores and labels for this group
            group_scores = scores[offset:offset + group_size]
            group_labels = labels[offset:offset + group_size]
            # Skip groups without any positive label
            if group_labels.sum() == 0:
                offset += group_size
                continue
            # Apply temperature scaling
            scaled_scores = group_scores / self.temperature
            # Apply log_softmax for numerical stability
            log_probs = F.log_softmax(scaled_scores, dim=0)
            # Normalize labels to create a target probability distribution
            label_probs = group_labels.float() / group_labels.sum()
            # Compute cross-entropy: -sum(p_true * log(p_pred))
            loss = -(label_probs * log_probs).sum()
            total_loss = total_loss + loss
            num_valid_groups += 1
            offset += group_size
        # Average over the groups that actually contributed a loss
        if num_valid_groups > 0:
            return total_loss / num_valid_groups
        return total_loss
class PairwiseMarginLoss(nn.Module):
"""
Pairwise margin ranking loss for reranker
Notes:
- All parameters (margin, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
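    Example (illustrative usage sketch; the scores, labels, and group sizes are made-up values):
        >>> loss_fn = PairwiseMarginLoss(margin=1.0)
        >>> scores = torch.tensor([1.2, 0.3, 0.8, 0.1])    # 2 queries x 2 passages
        >>> labels = torch.tensor([1, 0, 1, 0])
        >>> loss = loss_fn(scores, labels, groups=[2, 2])  # one pos/neg pair per group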
"""
def __init__(self, margin: float = 1.0):
super().__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
"""
Args:
scores: Relevance scores [total_passages]
labels: Binary relevance labels [total_passages]
groups: List of group sizes
"""
total_loss = torch.tensor(0.0, device=scores.device)
num_pairs = 0
offset = 0
for group_size in groups:
group_scores = scores[offset:offset + group_size]
group_labels = labels[offset:offset + group_size]
# Find positive and negative indices
pos_indices = torch.where(group_labels == 1)[0]
neg_indices = torch.where(group_labels == 0)[0]
if len(pos_indices) > 0 and len(neg_indices) > 0:
# Create all positive-negative pairs
for pos_idx in pos_indices:
for neg_idx in neg_indices:
pos_score = group_scores[pos_idx].unsqueeze(0)
neg_score = group_scores[neg_idx].unsqueeze(0)
target = torch.ones(1, device=scores.device)
loss = self.ranking_loss(pos_score, neg_score, target)
total_loss = total_loss + loss
num_pairs += 1
offset += group_size
return total_loss / max(num_pairs, 1)
class DynamicListwiseDistillationLoss(nn.Module):
"""
Dynamic listwise distillation loss for RocketQAv2-style joint training
Notes:
- All parameters (temperature, dynamic_weight, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
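    Example (illustrative usage sketch; the scores, group sizes, and confidences are assumptions):
        >>> loss_fn = DynamicListwiseDistillationLoss(temperature=3.0)
        >>> student = torch.tensor([1.0, 0.2, 0.1, 0.9])
        >>> teacher = torch.tensor([2.0, 0.1, 0.0, 1.5])
        >>> confidence = torch.tensor([1.0, 0.8])          # one weight per query group
        >>> loss = loss_fn(student, teacher, groups=[2, 2], confidence_scores=confidence)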
"""
def __init__(self, temperature: float = 3.0, dynamic_weight: bool = True):
super().__init__()
self.temperature = temperature
self.dynamic_weight = dynamic_weight
self.kl_div = nn.KLDivLoss(reduction='none')
def forward(
self,
student_scores: torch.Tensor,
teacher_scores: torch.Tensor,
groups: List[int],
confidence_scores: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
student_scores: Student model scores [total_passages]
teacher_scores: Teacher model scores [total_passages]
groups: List of group sizes
confidence_scores: Optional confidence for dynamic weighting
"""
total_loss = torch.tensor(0.0, device=student_scores.device)
offset = 0
        for i, group_size in enumerate(groups):
            if group_size == 0:
                continue
            student_group = student_scores[offset:offset + group_size] / self.temperature
            teacher_group = teacher_scores[offset:offset + group_size] / self.temperature
            # Convert to distributions (student as log-probabilities, as KLDivLoss expects)
            student_log_probs = F.log_softmax(student_group, dim=0)
            teacher_probs = F.softmax(teacher_group, dim=0)
            # KL divergence between teacher and student ranking distributions
            kl_loss = self.kl_div(student_log_probs, teacher_probs).sum()
# Apply dynamic weighting based on confidence
if self.dynamic_weight and confidence_scores is not None:
weight = confidence_scores[i]
kl_loss = kl_loss * weight
total_loss = total_loss + kl_loss * (self.temperature ** 2)
offset += group_size
        # Average over groups (guard against an empty group list)
        return total_loss / max(len(groups), 1)
class MultiTaskLoss(nn.Module):
"""
Multi-task loss for combining multiple objectives
Notes:
- All parameters (loss_weights, adaptive_weights, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
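    Example (illustrative usage sketch; the individual loss values below are made up):
        >>> mtl = MultiTaskLoss(loss_weights={'dense': 1.0, 'sparse': 0.3})
        >>> losses = {'dense': torch.tensor(0.8), 'sparse': torch.tensor(1.2)}
        >>> total, loss_dict = mtl(losses)   # fixed weighting: 1.0*0.8 + 0.3*1.2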
"""
def __init__(
self,
loss_weights: Optional[Dict[str, float]] = None,
adaptive_weights: bool = False
):
super().__init__()
self.loss_weights = loss_weights or {
'dense': 1.0,
'sparse': 0.3,
'colbert': 0.5,
'distillation': 1.0
}
self.adaptive_weights = adaptive_weights
if adaptive_weights:
# Learnable loss weights (uncertainty weighting)
self.log_vars = nn.Parameter(torch.zeros(len(self.loss_weights)))
def forward(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
losses: Dictionary of individual losses
Returns:
total_loss: Combined loss
loss_dict: Dictionary of individual loss values
"""
total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device if losses else 'cpu')
loss_dict = {}
        if self.adaptive_weights:
            # Uncertainty-based weighting: scale each loss by exp(-log_var) and add
            # log_var as a regularizer. Index log_vars by the loss name's position in
            # self.loss_weights so the mapping does not depend on the iteration order
            # (or extra keys) of the `losses` dict passed in.
            weight_names = list(self.loss_weights.keys())
            for loss_name, loss_value in losses.items():
                if loss_name in self.loss_weights:
                    idx = weight_names.index(loss_name)
                    precision = torch.exp(-self.log_vars[idx])
                    weighted_loss = precision * loss_value + self.log_vars[idx]
                    total_loss = total_loss + weighted_loss
                    loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value)
                    loss_dict[f'{loss_name}_weight'] = precision.detach().cpu().item()
else:
# Fixed weighting
for loss_name, loss_value in losses.items():
if loss_name in self.loss_weights:
weight = self.loss_weights[loss_name]
weighted_loss = weight * loss_value
total_loss = total_loss + weighted_loss
loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value)
loss_dict[f'{loss_name}_weight'] = weight
loss_dict['total_loss'] = total_loss.detach().cpu().item() if isinstance(total_loss, torch.Tensor) else float(total_loss)
return total_loss, loss_dict
class SparseLoss(nn.Module):
"""
SPLADE+ sparse loss with log1p(ReLU(.)) and FLOPS regularization.
For BGE-M3's token-position based sparse representations.
Args:
flops_loss_weight: Weight for FLOPS regularization
sparse_reg_weight: Weight for sparsity regularization
temperature: Temperature for similarity computation
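    Example (illustrative usage sketch; batch size and sequence lengths are assumptions):
        >>> sparse_loss = SparseLoss(flops_loss_weight=1e-3, sparse_reg_weight=1e-4)
        >>> query_sparse = torch.randn(4, 32)    # [batch, query_seq_len] raw token weights
        >>> doc_sparse = torch.randn(4, 128)     # [batch, doc_seq_len] raw token weights
        >>> labels = torch.arange(4)             # accepted but unused (in-batch positives)
        >>> loss, loss_dict = sparse_loss(query_sparse, doc_sparse, labels)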
"""
def __init__(
self,
flops_loss_weight: float = 1e-3,
sparse_reg_weight: float = 1e-4,
temperature: float = 0.02
):
super().__init__()
self.flops_loss_weight = flops_loss_weight
self.sparse_reg_weight = sparse_reg_weight
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss()
def forward(
self,
query_sparse: torch.Tensor,
doc_sparse: torch.Tensor,
labels: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
query_sparse: [batch, query_seq_len] - Query sparse embeddings (token-position based)
doc_sparse: [batch, doc_seq_len] - Document sparse embeddings (token-position based)
            labels: [batch] - Labels for contrastive learning (currently unused; in-batch positives are assumed)
Returns:
Tuple of (loss, loss_dict)
"""
batch_size = query_sparse.size(0)
device = query_sparse.device
# SPLADE+ uses log1p(ReLU(.))
query_splade = torch.log1p(F.relu(query_sparse))
doc_splade = torch.log1p(F.relu(doc_sparse))
# For BGE-M3's token-position sparse embeddings, compute similarity using aggregated scores
# Since we have different sequence lengths, we need to aggregate to comparable representations
# Sum pooling to get document-level sparse scores
query_scores = query_splade.sum(dim=1, keepdim=True) # [batch, 1]
doc_scores = doc_splade.sum(dim=1, keepdim=True) # [batch, 1]
        # Compute a [batch, batch] similarity matrix for in-batch contrastive learning.
        # With sum-pooled scalar scores, the outer product of query and document
        # activation mass serves as a crude similarity proxy; cosine similarity is not
        # applicable here because L2-normalizing a [batch, 1] tensor collapses every
        # positive entry to 1.
        similarity_matrix = query_scores * doc_scores.t()  # [batch, batch]
# Scale by temperature for contrastive loss
logits = similarity_matrix / self.temperature
        # Labels for in-batch contrastive learning: diagonal entries are positive pairs
        # (the `labels` argument is accepted for API compatibility but not used here)
contrastive_labels = torch.arange(batch_size, device=device, dtype=torch.long)
# Contrastive loss
main_loss = self.cross_entropy(logits, contrastive_labels)
# FLOPS regularization (encourage sparsity)
flops = (query_splade.sum(dim=-1) + doc_splade.sum(dim=-1)).mean()
# L1 sparsity regularization
sparsity_loss = (query_splade.sum() + doc_splade.sum()) / (query_splade.numel() + doc_splade.numel())
# Combine losses
total_loss = main_loss + self.flops_loss_weight * flops + self.sparse_reg_weight * sparsity_loss
# Ensure all losses are tensors
if not isinstance(total_loss, torch.Tensor):
total_loss = torch.tensor(total_loss, device=device)
if not isinstance(main_loss, torch.Tensor):
main_loss = torch.tensor(main_loss, device=device)
if not isinstance(flops, torch.Tensor):
flops = torch.tensor(flops, device=device)
if not isinstance(sparsity_loss, torch.Tensor):
sparsity_loss = torch.tensor(sparsity_loss, device=device)
loss_dict = {
'main_loss': main_loss.detach().cpu().item(),
'flops': flops.detach().cpu().item(),
'sparsity_loss': sparsity_loss.detach().cpu().item(),
'total_loss': total_loss.detach().cpu().item()
}
return total_loss, loss_dict
class ColBERTLoss(nn.Module):
"""
ColBERT late interaction (maxsim) loss.
Args:
temperature: Softmax temperature
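    Example (illustrative usage sketch; token counts and embedding dim are assumptions):
        >>> colbert_loss = ColBERTLoss(temperature=0.02)
        >>> q = F.normalize(torch.randn(4, 16, 128), dim=-1)   # [batch, query_len, dim]
        >>> d = F.normalize(torch.randn(4, 64, 128), dim=-1)   # [batch, doc_len, dim]
        >>> labels = torch.tensor([1, 0, 1, 0])                # binary relevance per pair
        >>> loss = colbert_loss(q, d, labels)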
"""
def __init__(self, temperature: float = 0.02):
super().__init__()
self.temperature = temperature
def forward(
self,
query_embeds: torch.Tensor,
doc_embeds: torch.Tensor,
labels: torch.Tensor
) -> torch.Tensor:
"""
Args:
            query_embeds: [batch, query_len, dim]
            doc_embeds: [batch, doc_len, dim]
labels: [batch]
Returns:
Loss value
"""
# Late interaction: maxsim over doc tokens for each query token
# [batch, query_len, doc_len]
sim_matrix = torch.einsum('bqd,bkd->bqk', query_embeds, doc_embeds) / self.temperature
maxsim = sim_matrix.max(dim=2)[0] # [batch, query_len]
# Aggregate over query tokens (mean)
scores = maxsim.mean(dim=1) # [batch]
# Binary cross-entropy loss
loss = F.binary_cross_entropy_with_logits(scores, labels.float())
return loss