"""
|
|
Loss functions for BGE fine-tuning
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from typing import Dict, List, Optional, Tuple
|
|
import numpy as np
|
|
|
|
|
|
class InfoNCELoss(nn.Module):
    """
    InfoNCE loss for contrastive learning, as used in the BGE-M3 paper
    (temperature τ = 0.02).

    Notes:
        - All parameters (temperature, reduction, etc.) should be passed from
          config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(self, temperature: float = 0.02, reduction: str = 'mean'):
        super().__init__()
        self.temperature = temperature
        self.reduction = reduction
        self.cross_entropy = nn.CrossEntropyLoss(reduction=reduction)

    def forward(
        self,
        query_embeddings: torch.Tensor,
        positive_embeddings: torch.Tensor,
        negative_embeddings: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss with L2 normalization and numerical stability.

        Args:
            query_embeddings: [batch_size, embedding_dim]
            positive_embeddings: [batch_size, embedding_dim]
            negative_embeddings: [batch_size, num_negatives, embedding_dim] or None
        """
        batch_size = query_embeddings.size(0)
        device = query_embeddings.device

        # L2-normalize embeddings so dot products become cosine similarities
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        positive_embeddings = F.normalize(positive_embeddings, p=2, dim=1)

        # Positive scores (query-positive similarity), shape: [batch_size]
        pos_scores = torch.sum(query_embeddings * positive_embeddings, dim=1)
        pos_logits = pos_scores / self.temperature

        if negative_embeddings is not None and negative_embeddings.numel() > 0:
            # Normalize negative embeddings
            negative_embeddings = F.normalize(negative_embeddings, p=2, dim=-1)

            # Negative scores, shape: [batch_size, num_negatives]
            neg_scores = torch.matmul(
                query_embeddings.unsqueeze(1),
                negative_embeddings.transpose(1, 2)
            ).squeeze(1)
            neg_logits = neg_scores / self.temperature

            # Concatenate positive and negative logits,
            # shape: [batch_size, 1 + num_negatives]
            logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)

            # Labels: the positive is always at index 0
            labels = torch.zeros(batch_size, device=device, dtype=torch.long)
        else:
            # Fall back to in-batch negatives: every other positive in the
            # batch acts as a negative for the current query
            similarity_matrix = torch.matmul(query_embeddings, positive_embeddings.t())
            logits = similarity_matrix / self.temperature

            # Labels: positive pairs lie on the diagonal
            labels = torch.arange(batch_size, device=device, dtype=torch.long)

        # Cross-entropy over the candidate set
        loss = self.cross_entropy(logits, labels)

        return loss


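# Usage sketch (illustrative only, not part of the original training code):
# calling InfoNCELoss with explicit hard negatives. Batch size, embedding
# dimension, and negative count are arbitrary placeholders.
def _example_infonce_usage() -> torch.Tensor:
    loss_fn = InfoNCELoss(temperature=0.02)
    query = torch.randn(8, 1024)          # [batch_size, embedding_dim]
    positive = torch.randn(8, 1024)       # one positive per query
    negatives = torch.randn(8, 4, 1024)   # [batch_size, num_negatives, embedding_dim]
    return loss_fn(query, positive, negatives)

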
class M3KnowledgeDistillationLoss(nn.Module):
    """
    Knowledge distillation loss for BGE-M3's multi-functionality training.
    Distills knowledge from the dense, sparse, and multi-vector (ColBERT)
    representations.

    Notes:
        - All parameters (temperature, alpha, distill_types, etc.) should be
          passed from config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(
        self,
        temperature: float = 4.0,
        alpha: float = 0.5,
        distill_types: Optional[List[str]] = None
    ):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        # Avoid a mutable default argument; fall back to all three representations
        self.distill_types = distill_types if distill_types is not None else ['dense', 'sparse', 'colbert']
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(
        self,
        student_outputs: Dict[str, torch.Tensor],
        teacher_outputs: Dict[str, torch.Tensor],
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            student_outputs: Dictionary with 'dense', 'sparse', 'colbert' outputs
            teacher_outputs: Dictionary with teacher model outputs
            labels: Optional labels for supervised loss

        Returns:
            total_loss: Combined distillation loss
            loss_dict: Dictionary of individual losses
        """
        total_loss = torch.tensor(
            0.0,
            device=next(iter(student_outputs.values())).device if student_outputs else 'cpu'
        )
        loss_dict = {}

        # Distillation loss for each representation type
        for repr_type in self.distill_types:
            if repr_type in student_outputs and repr_type in teacher_outputs:
                student_logits = student_outputs[repr_type]
                teacher_logits = teacher_outputs[repr_type]

                # Apply temperature scaling
                student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
                teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)

                # KL divergence, rescaled by T^2 to keep gradient magnitudes comparable
                kl_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)
                total_loss = total_loss + self.alpha * kl_loss
                loss_dict[f'{repr_type}_distill_loss'] = kl_loss.detach().cpu().item()

        # Add supervised loss if labels are provided
        if labels is not None and 'dense' in student_outputs:
            ce_loss = F.cross_entropy(student_outputs['dense'], labels)
            total_loss = total_loss + (1 - self.alpha) * ce_loss
            loss_dict['supervised_loss'] = ce_loss.detach().cpu().item()

        loss_dict['total_loss'] = total_loss.detach().cpu().item()
        return total_loss, loss_dict


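# Usage sketch (illustrative only, not part of the original training pipeline):
# distilling only the dense head of a student from a teacher. Shapes and
# values are arbitrary placeholders.
def _example_m3_distillation_usage() -> torch.Tensor:
    kd_loss = M3KnowledgeDistillationLoss(temperature=4.0, alpha=0.5, distill_types=['dense'])
    student = {'dense': torch.randn(8, 16)}   # student scores/logits per candidate
    teacher = {'dense': torch.randn(8, 16)}   # teacher scores/logits per candidate
    total, parts = kd_loss(student, teacher)  # parts holds per-representation losses
    return total

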
class ListwiseCrossEntropyLoss(nn.Module):
    """
    Listwise cross-entropy loss for reranker training.

    Notes:
        - All parameters (temperature, etc.) should be passed from
          config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
        """
        Compute listwise cross-entropy loss with proper normalization.

        Args:
            scores: Relevance scores [total_passages]
            labels: Binary relevance labels [total_passages]
            groups: List of group sizes (passages per query)
        """
        total_loss = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
        offset = 0

        for group_size in groups:
            if group_size == 0:
                continue

            # Scores and labels for this group
            group_scores = scores[offset:offset + group_size]
            group_labels = labels[offset:offset + group_size]

            # Skip groups with no positive labels
            if group_labels.sum() == 0:
                offset += group_size
                continue

            # Apply temperature scaling
            scaled_scores = group_scores / self.temperature

            # log_softmax over the group for numerical stability
            log_probs = F.log_softmax(scaled_scores, dim=0)

            # Normalize labels into a target probability distribution
            label_probs = group_labels.float() / group_labels.sum()

            # Cross-entropy: -sum(p_true * log(p_pred))
            loss = -(label_probs * log_probs).sum()

            total_loss = total_loss + loss
            offset += group_size

        # Average over the number of non-empty groups
        num_valid_groups = sum(1 for g in groups if g > 0)
        if num_valid_groups > 0:
            return total_loss / num_valid_groups
        else:
            return total_loss


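# Usage sketch (illustrative only): two queries with three candidate passages
# each, flattened into a single score/label tensor plus per-query group sizes.
def _example_listwise_ce_usage() -> torch.Tensor:
    rank_loss = ListwiseCrossEntropyLoss(temperature=1.0)
    scores = torch.randn(6)                                 # reranker scores, [total_passages]
    labels = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 1.0])   # binary relevance
    return rank_loss(scores, labels, groups=[3, 3])

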
class PairwiseMarginLoss(nn.Module):
    """
    Pairwise margin ranking loss for reranker training.

    Notes:
        - All parameters (margin, etc.) should be passed from config-driven
          training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
        """
        Args:
            scores: Relevance scores [total_passages]
            labels: Binary relevance labels [total_passages]
            groups: List of group sizes
        """
        total_loss = torch.tensor(0.0, device=scores.device)
        num_pairs = 0
        offset = 0

        for group_size in groups:
            group_scores = scores[offset:offset + group_size]
            group_labels = labels[offset:offset + group_size]

            # Find positive and negative indices within the group
            pos_indices = torch.where(group_labels == 1)[0]
            neg_indices = torch.where(group_labels == 0)[0]

            if len(pos_indices) > 0 and len(neg_indices) > 0:
                # Create all positive-negative pairs
                for pos_idx in pos_indices:
                    for neg_idx in neg_indices:
                        pos_score = group_scores[pos_idx].unsqueeze(0)
                        neg_score = group_scores[neg_idx].unsqueeze(0)
                        target = torch.ones(1, device=scores.device)

                        loss = self.ranking_loss(pos_score, neg_score, target)
                        total_loss = total_loss + loss
                        num_pairs += 1

            offset += group_size

        # Average over the number of pairs (guarding against zero pairs)
        return total_loss / max(num_pairs, 1)


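# Usage sketch (illustrative only): same flattened score/label layout as the
# listwise loss above; every positive is paired against every negative within
# its group.
def _example_pairwise_margin_usage() -> torch.Tensor:
    pair_loss = PairwiseMarginLoss(margin=1.0)
    scores = torch.randn(6)                     # reranker scores, [total_passages]
    labels = torch.tensor([1, 0, 0, 1, 1, 0])   # binary relevance
    return pair_loss(scores, labels, groups=[3, 3])

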
class DynamicListwiseDistillationLoss(nn.Module):
    """
    Dynamic listwise distillation loss for RocketQAv2-style joint training.

    Notes:
        - All parameters (temperature, dynamic_weight, etc.) should be passed
          from config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(self, temperature: float = 3.0, dynamic_weight: bool = True):
        super().__init__()
        self.temperature = temperature
        self.dynamic_weight = dynamic_weight
        self.kl_div = nn.KLDivLoss(reduction='none')

    def forward(
        self,
        student_scores: torch.Tensor,
        teacher_scores: torch.Tensor,
        groups: List[int],
        confidence_scores: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            student_scores: Student model scores [total_passages]
            teacher_scores: Teacher model scores [total_passages]
            groups: List of group sizes
            confidence_scores: Optional per-group confidence for dynamic weighting
        """
        total_loss = torch.tensor(0.0, device=student_scores.device)
        offset = 0

        for i, group_size in enumerate(groups):
            student_group = student_scores[offset:offset + group_size] / self.temperature
            teacher_group = teacher_scores[offset:offset + group_size] / self.temperature

            # Convert scores to listwise distributions over the group
            student_log_probs = F.log_softmax(student_group, dim=0)
            teacher_probs = F.softmax(teacher_group, dim=0)

            # KL divergence between teacher and student distributions
            kl_loss = self.kl_div(student_log_probs, teacher_probs).sum()

            # Apply dynamic weighting based on per-group confidence
            if self.dynamic_weight and confidence_scores is not None:
                weight = confidence_scores[i]
                kl_loss = kl_loss * weight

            total_loss = total_loss + kl_loss * (self.temperature ** 2)
            offset += group_size

        # Average over groups (guarding against an empty group list)
        return total_loss / max(len(groups), 1)


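# Usage sketch (illustrative only): distilling a cross-encoder teacher's
# listwise ranking distribution into a student, with one confidence weight per
# query group. All tensors are random placeholders.
def _example_dynamic_distillation_usage() -> torch.Tensor:
    distill_loss = DynamicListwiseDistillationLoss(temperature=3.0, dynamic_weight=True)
    student_scores = torch.randn(6)          # student scores, [total_passages]
    teacher_scores = torch.randn(6)          # teacher scores, [total_passages]
    confidence = torch.tensor([1.0, 0.5])    # one confidence value per group
    return distill_loss(student_scores, teacher_scores, groups=[3, 3], confidence_scores=confidence)

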
class MultiTaskLoss(nn.Module):
    """
    Multi-task loss for combining multiple objectives.

    Notes:
        - All parameters (loss_weights, adaptive_weights, etc.) should be
          passed from config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(
        self,
        loss_weights: Optional[Dict[str, float]] = None,
        adaptive_weights: bool = False
    ):
        super().__init__()
        self.loss_weights = loss_weights or {
            'dense': 1.0,
            'sparse': 0.3,
            'colbert': 0.5,
            'distillation': 1.0
        }
        self.adaptive_weights = adaptive_weights
        # Stable name-to-index mapping so each loss always uses the same
        # learnable weight, regardless of the ordering of the incoming dict
        self._loss_index = {name: i for i, name in enumerate(self.loss_weights)}

        if adaptive_weights:
            # Learnable loss weights (uncertainty weighting)
            self.log_vars = nn.Parameter(torch.zeros(len(self.loss_weights)))

    def forward(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            losses: Dictionary of individual losses

        Returns:
            total_loss: Combined loss
            loss_dict: Dictionary of individual loss values
        """
        total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device if losses else 'cpu')
        loss_dict = {}

        if self.adaptive_weights:
            # Uncertainty-based weighting: precision * loss + log_var
            for loss_name, loss_value in losses.items():
                if loss_name in self.loss_weights:
                    idx = self._loss_index[loss_name]
                    precision = torch.exp(-self.log_vars[idx])
                    weighted_loss = precision * loss_value + self.log_vars[idx]
                    total_loss = total_loss + weighted_loss
                    loss_dict[f'{loss_name}_loss'] = float(loss_value)
                    loss_dict[f'{loss_name}_weight'] = float(precision)
        else:
            # Fixed weighting
            for loss_name, loss_value in losses.items():
                if loss_name in self.loss_weights:
                    weight = self.loss_weights[loss_name]
                    weighted_loss = weight * loss_value
                    total_loss = total_loss + weighted_loss
                    loss_dict[f'{loss_name}_loss'] = float(loss_value)
                    loss_dict[f'{loss_name}_weight'] = weight

        loss_dict['total_loss'] = float(total_loss)
        return total_loss, loss_dict


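# Usage sketch (illustrative only): combining already-computed per-task losses
# with fixed weights. The keys must match the configured loss_weights.
def _example_multi_task_usage() -> torch.Tensor:
    multi_loss = MultiTaskLoss(loss_weights={'dense': 1.0, 'sparse': 0.3})
    losses = {'dense': torch.tensor(0.8), 'sparse': torch.tensor(1.2)}
    total, parts = multi_loss(losses)   # parts also records the weight applied to each task
    return total

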
class SparseLoss(nn.Module):
    """
    SPLADE-style sparse loss with log1p(ReLU(.)) activation and FLOPS
    regularization, for BGE-M3's token-position based sparse representations.

    Args:
        flops_loss_weight: Weight for FLOPS regularization
        sparse_reg_weight: Weight for L1 sparsity regularization
        temperature: Temperature for similarity computation
    """

    def __init__(
        self,
        flops_loss_weight: float = 1e-3,
        sparse_reg_weight: float = 1e-4,
        temperature: float = 0.02
    ):
        super().__init__()
        self.flops_loss_weight = flops_loss_weight
        self.sparse_reg_weight = sparse_reg_weight
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(
        self,
        query_sparse: torch.Tensor,
        doc_sparse: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            query_sparse: [batch, query_seq_len] - query sparse term weights (token-position based)
            doc_sparse: [batch, doc_seq_len] - document sparse term weights (token-position based)
            labels: [batch] - unused; in-batch contrastive labels are constructed internally

        Returns:
            Tuple of (loss, loss_dict)
        """
        batch_size = query_sparse.size(0)
        device = query_sparse.device

        # SPLADE+ activation: log1p(ReLU(.))
        query_splade = torch.log1p(F.relu(query_sparse))
        doc_splade = torch.log1p(F.relu(doc_sparse))

        # Query and document sequences have different lengths, so the
        # token-position weights are first aggregated (sum pooling) into a
        # single relevance mass per example.
        query_scores = query_splade.sum(dim=1, keepdim=True)  # [batch, 1]
        doc_scores = doc_splade.sum(dim=1, keepdim=True)      # [batch, 1]

        # In-batch similarity matrix via the outer product of the aggregated
        # scores, shape [batch, batch]; diagonal entries are the positive pairs.
        similarity_matrix = query_scores * doc_scores.t()

        # Scale by temperature for the contrastive loss
        logits = similarity_matrix / self.temperature

        # In-batch contrastive labels (diagonal elements are positive pairs)
        contrastive_labels = torch.arange(batch_size, device=device, dtype=torch.long)

        # Contrastive loss
        main_loss = self.cross_entropy(logits, contrastive_labels)

        # FLOPS regularization (encourages sparsity)
        flops = (query_splade.sum(dim=-1) + doc_splade.sum(dim=-1)).mean()

        # L1 sparsity regularization
        sparsity_loss = (query_splade.sum() + doc_splade.sum()) / (query_splade.numel() + doc_splade.numel())

        # Combine losses
        total_loss = main_loss + self.flops_loss_weight * flops + self.sparse_reg_weight * sparsity_loss

        loss_dict = {
            'main_loss': main_loss.detach().cpu().item(),
            'flops': flops.detach().cpu().item(),
            'sparsity_loss': sparsity_loss.detach().cpu().item(),
            'total_loss': total_loss.detach().cpu().item()
        }

        return total_loss, loss_dict


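# Usage sketch (illustrative only): token-position sparse weights for queries
# and documents of different lengths. The labels argument is a placeholder;
# in-batch contrastive labels are built inside the loss.
def _example_sparse_loss_usage() -> torch.Tensor:
    sparse_loss = SparseLoss(flops_loss_weight=1e-3, sparse_reg_weight=1e-4, temperature=0.02)
    query_weights = torch.randn(8, 64)     # [batch, query_seq_len]
    doc_weights = torch.randn(8, 128)      # [batch, doc_seq_len]
    labels = torch.arange(8)               # unused placeholder
    total, parts = sparse_loss(query_weights, doc_weights, labels)
    return total

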
class ColBERTLoss(nn.Module):
    """
    ColBERT late-interaction (MaxSim) loss.

    Args:
        temperature: Softmax temperature
    """

    def __init__(self, temperature: float = 0.02):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query_embeds: torch.Tensor,
        doc_embeds: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            query_embeds: [batch, query_len, dim]
            doc_embeds: [batch, doc_len, dim]
            labels: [batch] binary relevance labels

        Returns:
            Loss value
        """
        # Late interaction: similarity of every query token with every doc token,
        # shape [batch, query_len, doc_len]
        sim_matrix = torch.einsum('bqd,bkd->bqk', query_embeds, doc_embeds) / self.temperature

        # MaxSim: best-matching doc token for each query token, [batch, query_len]
        maxsim = sim_matrix.max(dim=2)[0]

        # Aggregate over query tokens (mean) to get one score per pair, [batch]
        scores = maxsim.mean(dim=1)

        # Binary cross-entropy against the relevance labels
        loss = F.binary_cross_entropy_with_logits(scores, labels.float())
        return loss

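
# Usage sketch (illustrative only): per-token query and document embeddings
# with a binary relevance label per pair. Dimensions are arbitrary placeholders.
def _example_colbert_loss_usage() -> torch.Tensor:
    colbert_loss = ColBERTLoss(temperature=0.02)
    query_tokens = torch.randn(8, 32, 128)    # [batch, query_len, dim]
    doc_tokens = torch.randn(8, 180, 128)     # [batch, doc_len, dim]
    labels = torch.randint(0, 2, (8,))        # binary relevance per query-doc pair
    return colbert_loss(query_tokens, doc_tokens, labels)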