init commit

This commit is contained in:
ldy
2025-07-22 16:55:25 +08:00
commit 36003b83e2
67 changed files with 76613 additions and 0 deletions

25
models/__init__.py Normal file
View File

@@ -0,0 +1,25 @@
"""
Model implementations for BGE fine-tuning
"""
from .bge_m3 import BGEM3Model
from .bge_reranker import BGERerankerModel
from .losses import (
InfoNCELoss,
M3KnowledgeDistillationLoss,
ListwiseCrossEntropyLoss,
PairwiseMarginLoss,
DynamicListwiseDistillationLoss,
MultiTaskLoss
)
__all__ = [
'BGEM3Model',
'BGERerankerModel',
'InfoNCELoss',
'M3KnowledgeDistillationLoss',
'ListwiseCrossEntropyLoss',
'PairwiseMarginLoss',
'DynamicListwiseDistillationLoss',
'MultiTaskLoss'
]

768
models/bge_m3.py Normal file
View File

@@ -0,0 +1,768 @@
"""
BGE-M3 model wrapper for multi-functionality embedding
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
import logging
from utils.config_loader import get_config
import os
logger = logging.getLogger(__name__)
def load_model_and_tokenizer(model_name_or_path, cache_dir=None):
    """
    Load model and tokenizer with comprehensive error handling and fallback strategies.

    Args:
        model_name_or_path: Model identifier or local path
        cache_dir: Cache directory for downloaded models
    Returns:
        Tuple of (model, tokenizer, config)
    Raises:
        RuntimeError: If model loading fails after all attempts
    """
    config = get_config()
    model_source = config.get('model_paths', 'source', 'huggingface')
    # Track loading attempts so the final error message shows every failure.
    attempts = []
    loaders = {
        'huggingface': _load_from_huggingface,
        'modelscope': _load_from_modelscope,
    }
    # Try the primary (configured) loading method first.
    try:
        if model_source in loaders:
            return loaders[model_source](model_name_or_path, cache_dir)
        raise ValueError(f"Unknown model source: {model_source}")
    except Exception as e:
        attempts.append(f"{model_source}: {str(e)}")
        logger.warning(f"Primary loading from {model_source} failed: {e}")
    # Fallback: try every known source other than the one that already failed.
    # (Bug fix: an unknown/typo'd primary source previously fell back to
    # modelscope only and never tried huggingface.)
    fallback_sources = [s for s in loaders if s != model_source]
    for fallback in fallback_sources:
        try:
            logger.info(f"Attempting fallback loading from {fallback}")
            return loaders[fallback](model_name_or_path, cache_dir)
        except Exception as e:
            attempts.append(f"{fallback}: {str(e)}")
            logger.warning(f"Fallback loading from {fallback} failed: {e}")
    # Last resort: treat the identifier as a local directory if it exists.
    if os.path.exists(model_name_or_path):
        try:
            logger.info(f"Attempting to load from local path: {model_name_or_path}")
            return _load_from_local(model_name_or_path, cache_dir)
        except Exception as e:
            attempts.append(f"local: {str(e)}")
            logger.warning(f"Local loading failed: {e}")
    # All attempts failed
    error_msg = "Failed to load model from any source. Attempts:\n" + "\n".join(attempts)
    logger.error(error_msg)
    raise RuntimeError(error_msg)
def _load_from_huggingface(model_name_or_path, cache_dir=None):
    """Load model from HuggingFace with error handling"""
    try:
        from transformers import AutoModel, AutoTokenizer, AutoConfig

        common_kwargs = {'cache_dir': cache_dir, 'trust_remote_code': True}
        # Fetch the config first: this validates that the identifier resolves
        # to a real model before downloading the (much larger) weights.
        hf_config = AutoConfig.from_pretrained(model_name_or_path, **common_kwargs)
        hf_model = AutoModel.from_pretrained(
            model_name_or_path, config=hf_config, **common_kwargs
        )
        hf_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **common_kwargs)
        logger.info(f"Successfully loaded model from HuggingFace: {model_name_or_path}")
        return hf_model, hf_tokenizer, hf_config
    except ImportError as e:
        raise ImportError(f"Transformers library not available: {e}")
    except Exception as e:
        raise RuntimeError(f"HuggingFace loading error: {e}")
def _load_from_modelscope(model_name_or_path, cache_dir=None):
    """Load model from ModelScope with error handling"""
    try:
        from modelscope.models import Model as MSModel
        from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download

        # Resolve (downloading if necessary) to a local snapshot directory.
        snapshot_dir = ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
        ms_model = MSModel.from_pretrained(snapshot_dir)

        # Tokenizer is best-effort: ModelScope snapshots are usually HF-compatible.
        ms_tokenizer = None
        try:
            from transformers import AutoTokenizer
            ms_tokenizer = AutoTokenizer.from_pretrained(snapshot_dir)
        except Exception as e:
            logger.warning(f"Could not load tokenizer from ModelScope: {e}")

        # Config is best-effort too: read the raw config.json when present.
        raw_config = None
        cfg_file = os.path.join(snapshot_dir, 'config.json')
        if os.path.exists(cfg_file):
            import json
            with open(cfg_file, 'r') as f:
                raw_config = json.load(f)

        logger.info(f"Successfully loaded model from ModelScope: {model_name_or_path}")
        return ms_model, ms_tokenizer, raw_config
    except ImportError as e:
        raise ImportError(f"ModelScope library not available: {e}")
    except Exception as e:
        raise RuntimeError(f"ModelScope loading error: {e}")
def _load_from_local(model_path, cache_dir=None):
"""Load model from local directory with error handling"""
try:
from transformers import AutoModel, AutoTokenizer, AutoConfig
if not os.path.isdir(model_path):
raise ValueError(f"Local path is not a directory: {model_path}")
# Check for required files
required_files = ['config.json']
missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
if missing_files:
raise FileNotFoundError(f"Missing required files: {missing_files}")
# Load components
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, config=model_config, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
logger.info(f"Successfully loaded model from local path: {model_path}")
return model, tokenizer, model_config
except Exception as e:
raise RuntimeError(f"Local loading error: {e}")
class BGEM3Model(nn.Module):
    """
    BGE-M3 model wrapper supporting dense, sparse, and multi-vector retrieval.

    Notes:
        - All parameters (model_name_or_path, use_dense, use_sparse, use_colbert,
          pooling_method, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your
          main script for consistency.
    """
    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-m3",
        use_dense: bool = True,
        use_sparse: bool = True,
        use_colbert: bool = True,
        pooling_method: str = "cls",
        normalize_embeddings: bool = True,
        sentence_pooling_method: str = "cls",
        colbert_dim: int = -1,
        sparse_top_k: int = 100,
        cache_dir: Optional[str] = None,
        device: Optional[str] = None,
        gradient_checkpointing: bool = False
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.use_dense = use_dense
        self.use_sparse = use_sparse
        self.use_colbert = use_colbert
        self.pooling_method = pooling_method
        self.normalize_embeddings = normalize_embeddings
        # Bug fix: forward() reads self.sentence_pooling_method, but it was never
        # stored, which raised AttributeError on the first forward pass.
        self.sentence_pooling_method = sentence_pooling_method
        self.colbert_dim = colbert_dim
        self.sparse_top_k = sparse_top_k
        self.gradient_checkpointing = gradient_checkpointing
        # Initialize model with comprehensive error handling
        try:
            logger.info(f"Loading BGE-M3 model from {model_name_or_path} using source from config.toml...")
            self.model, self.tokenizer, self.config = load_model_and_tokenizer(
                model_name_or_path,
                cache_dir=cache_dir
            )
            logger.info("BGE-M3 model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BGE-M3 model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")
        # Enable gradient checkpointing if requested
        if self.gradient_checkpointing:
            self.enable_gradient_checkpointing()
        # Ensure we have a valid config (ModelScope loading may return None)
        if self.config is None:
            logger.warning("Model config is None, creating a basic config")
            # Create a basic config object
            class BasicConfig:
                def __init__(self):
                    self.hidden_size = 768  # Default BERT hidden size
            self.config = BasicConfig()
        # Initialize heads for different functionalities
        hidden_size = getattr(self.config, 'hidden_size', 768)
        if self.use_dense:
            # Dense retrieval doesn't need additional layers (uses CLS token)
            self.dense_linear = nn.Identity()
        if self.use_sparse:
            # Sparse retrieval head (SPLADE-style): per-token scalar weight
            self.sparse_linear = nn.Linear(hidden_size, 1)
            self.sparse_activation = nn.ReLU()
        if self.use_colbert:
            # ColBERT projection layer; colbert_dim <= 0 keeps the hidden size
            colbert_dim = colbert_dim if colbert_dim > 0 else hidden_size
            self.colbert_linear = nn.Linear(hidden_size, colbert_dim)
        # Set device with enhanced detection
        if device:
            self.device = torch.device(device)
        else:
            self.device = self._detect_device()
        try:
            self.to(self.device)
            logger.info(f"BGE-M3 model moved to device: {self.device}")
        except Exception as e:
            logger.warning(f"Failed to move model to {self.device}: {e}, using CPU")
            self.device = torch.device('cpu')
            self.to(self.device)
        # Cache for lazily-resolved tokenizer (see _get_tokenizer)
        self._tokenizer = None

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for memory efficiency"""
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()
            logger.info("Gradient checkpointing enabled for BGE-M3")
        else:
            logger.warning("Model does not support gradient checkpointing")

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing"""
        if hasattr(self.model, 'gradient_checkpointing_disable'):
            self.model.gradient_checkpointing_disable()
            logger.info("Gradient checkpointing disabled for BGE-M3")

    def _detect_device(self) -> torch.device:
        """Enhanced device detection with proper error handling"""
        # Check Ascend NPU first
        try:
            import torch_npu
            if hasattr(torch, 'npu') and torch.npu.is_available():  # type: ignore[attr-defined]
                logger.info("Using Ascend NPU")
                return torch.device("npu")
        except ImportError:
            pass
        except Exception as e:
            logger.debug(f"NPU detection failed: {e}")
        # Check CUDA
        if torch.cuda.is_available():
            logger.info("Using CUDA GPU")
            return torch.device("cuda")
        # Check MPS (Apple Silicon)
        try:
            if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():  # type: ignore[attr-defined]
                logger.info("Using Apple MPS")
                return torch.device("mps")
        except Exception as e:
            logger.debug(f"MPS detection failed: {e}")
        # Default to CPU
        logger.info("Using CPU")
        return torch.device("cpu")

    def _get_tokenizer(self):
        """Lazy load tokenizer with error handling"""
        if self._tokenizer is None:
            # Improvement: reuse the tokenizer already returned by the loader
            # instead of re-downloading it from the hub.
            if getattr(self, 'tokenizer', None) is not None:
                self._tokenizer = self.tokenizer
            else:
                try:
                    self._tokenizer = AutoTokenizer.from_pretrained(
                        self.model_name_or_path,
                        trust_remote_code=True
                    )
                    logger.info("Tokenizer loaded successfully")
                except Exception as e:
                    logger.error(f"Failed to load tokenizer: {e}")
                    raise RuntimeError(f"Tokenizer loading failed: {e}")
        return self._tokenizer

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        return_dense: bool = True,
        return_sparse: bool = False,
        return_colbert: bool = False,
        return_dict: bool = True
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass for BGE-M3 with dense, sparse (SPLADE+), and ColBERT heads.

        Args:
            input_ids: Token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            return_dense: Return dense embeddings
            return_sparse: Return sparse embeddings
            return_colbert: Return ColBERT embeddings
            return_dict: Return dictionary of outputs
        Returns:
            Embeddings or dictionary of embeddings
        """
        # Validate inputs
        if input_ids is None or attention_mask is None:
            raise ValueError("input_ids and attention_mask are required")
        if input_ids.shape != attention_mask.shape:
            raise ValueError("input_ids and attention_mask must have the same shape")
        # Requested heads are intersected with the heads this model was built with
        return_dense = return_dense and self.use_dense
        return_sparse = return_sparse and self.use_sparse
        return_colbert = return_colbert and self.use_colbert
        if not (return_dense or return_sparse or return_colbert):
            logger.warning("No output head selected in forward. At least one of dense, sparse, or colbert must be True.")
            return_dense = True  # Default to dense if nothing selected
        try:
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True
            )
            last_hidden_state = outputs.last_hidden_state  # [batch_size, seq_length, hidden_size]
        except Exception as e:
            logger.error(f"Model forward pass failed: {e}")
            raise RuntimeError(f"Forward pass failed: {e}")
        results = {}
        # Dense embeddings (CLS or mean pooling)
        if return_dense:
            if self.pooling_method == "cls" or self.sentence_pooling_method == "cls":
                dense_vecs = last_hidden_state[:, 0]
            elif self.pooling_method == "mean" or self.sentence_pooling_method == "mean":
                # Mean over non-padding tokens only
                expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.shape)
                masked_embeddings = last_hidden_state * expanded_mask
                sum_embeddings = torch.sum(masked_embeddings, dim=1)
                sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
                dense_vecs = sum_embeddings / sum_mask
            else:
                raise ValueError(f"Unknown pooling method: {self.pooling_method}")
            dense_vecs = self.dense_linear(dense_vecs)
            # Ensure embeddings are float type
            dense_vecs = dense_vecs.float()
            if self.normalize_embeddings:
                dense_vecs = F.normalize(dense_vecs, p=2, dim=-1)
            results['dense'] = dense_vecs
        # SPLADE+ sparse head
        if return_sparse:
            # SPLADE+ uses log1p(ReLU(.))
            sparse_logits = self.sparse_linear(last_hidden_state).squeeze(-1)  # [batch, seq_len]
            sparse_activations = torch.log1p(F.relu(sparse_logits))
            # Mask out padding tokens
            sparse_activations = sparse_activations * attention_mask.float()
            # Optionally normalize sparse activations
            if self.normalize_embeddings:
                sparse_activations = F.normalize(sparse_activations, p=2, dim=-1)
            results['sparse'] = sparse_activations
            # FLOPS regularization (for loss, not forward, but add as attribute)
            self.splade_flops = torch.sum(sparse_activations)
        # ColBERT head with late interaction (maxsim)
        if return_colbert:
            colbert_vecs = self.colbert_linear(last_hidden_state)  # [batch, seq_len, colbert_dim]
            # Mask out padding tokens
            expanded_mask = attention_mask.unsqueeze(-1).expand(colbert_vecs.shape)
            colbert_vecs = colbert_vecs * expanded_mask
            if self.normalize_embeddings:
                colbert_vecs = F.normalize(colbert_vecs, p=2, dim=-1)
            results['colbert'] = colbert_vecs
        if return_dict:
            return results
        # If only one head requested, return that tensor directly
        if return_dense and len(results) == 1:
            return results['dense']
        if return_sparse and len(results) == 1:
            return results['sparse']
        if return_colbert and len(results) == 1:
            return results['colbert']
        return results

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        show_progress_bar: bool = False,
        normalize_embeddings: bool = True,
        return_dense: bool = True,
        return_sparse: bool = False,
        return_colbert: bool = False,
        device: Optional[str] = None
    ) -> Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]:
        """
        Complete encode implementation with comprehensive error handling.

        Args:
            sentences: Input sentences
            batch_size: Batch size for encoding
            max_length: Maximum sequence length
            convert_to_numpy: Convert to numpy arrays
            convert_to_tensor: Keep as tensors
            show_progress_bar: Show progress bar
            normalize_embeddings: Normalize embeddings
            return_dense/sparse/colbert: Which embeddings to return
            device: Device to use
        Returns:
            Embeddings
        """
        # Input validation
        if isinstance(sentences, str):
            sentences = [sentences]
        if not sentences:
            raise ValueError("Empty sentences list provided")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if max_length <= 0:
            raise ValueError("max_length must be positive")
        # Get tokenizer
        try:
            tokenizer = self._get_tokenizer()
        except Exception as e:
            raise RuntimeError(f"Failed to load tokenizer: {e}")
        # Determine target device
        target_device = device if device else self.device
        # Initialize storage for embeddings (None means "not requested")
        all_embeddings = {
            'dense': [] if return_dense else None,
            'sparse': [] if return_sparse else None,
            'colbert': [] if return_colbert else None
        }
        # Set model to evaluation mode
        self.eval()
        try:
            # Import tqdm if show_progress_bar is True
            if show_progress_bar:
                try:
                    from tqdm import tqdm
                    iterator = tqdm(
                        range(0, len(sentences), batch_size),
                        desc="Encoding",
                        total=(len(sentences) + batch_size - 1) // batch_size
                    )
                except ImportError:
                    logger.warning("tqdm not available, disabling progress bar")
                    iterator = range(0, len(sentences), batch_size)
            else:
                iterator = range(0, len(sentences), batch_size)
            # Process in batches
            for i in iterator:
                batch_sentences = sentences[i:i + batch_size]
                # Tokenize with proper error handling
                try:
                    encoded = tokenizer(
                        batch_sentences,
                        padding=True,
                        truncation=True,
                        max_length=max_length,
                        return_tensors='pt'
                    )
                except Exception as e:
                    raise RuntimeError(f"Tokenization failed for batch {i//batch_size + 1}: {e}")
                # Move to device
                try:
                    encoded = {k: v.to(target_device) for k, v in encoded.items()}
                except Exception as e:
                    raise RuntimeError(f"Failed to move batch to device {target_device}: {e}")
                # Get embeddings
                with torch.no_grad():
                    try:
                        outputs = self.forward(
                            **encoded,
                            return_dense=return_dense,
                            return_sparse=return_sparse,
                            return_colbert=return_colbert,
                            return_dict=True
                        )
                    except Exception as e:
                        raise RuntimeError(f"Forward pass failed for batch {i//batch_size + 1}: {e}")
                # Collect embeddings (moved to CPU to bound device memory)
                for key in ['dense', 'sparse', 'colbert']:
                    if all_embeddings[key] is not None and key in outputs:
                        all_embeddings[key].append(outputs[key].cpu())  # type: ignore[arg-type]
            # Concatenate and convert results
            final_embeddings = {}
            for key in ['dense', 'sparse', 'colbert']:
                if all_embeddings[key] is not None and len(all_embeddings[key]) > 0:  # type: ignore[arg-type]
                    try:
                        embeddings = torch.cat(all_embeddings[key], dim=0)  # type: ignore[arg-type]
                        if normalize_embeddings and key == 'dense':
                            embeddings = F.normalize(embeddings, p=2, dim=-1)
                        if convert_to_numpy:
                            embeddings = embeddings.numpy()
                        final_embeddings[key] = embeddings
                    except Exception as e:
                        logger.error(f"Failed to process {key} embeddings: {e}")
                        raise
            # Return format: a single head unwraps to the bare array/tensor
            if len(final_embeddings) == 1:
                return list(final_embeddings.values())[0]
            if not final_embeddings:
                raise RuntimeError("No embeddings were generated")
            return final_embeddings
        except Exception as e:
            logger.error(f"Encoding failed: {e}")
            raise

    def save_pretrained(self, save_path: str):
        """Save model to directory with error handling, including custom heads and config."""
        try:
            import os
            import json
            os.makedirs(save_path, exist_ok=True)
            # Save the base model
            self.model.save_pretrained(save_path)
            # Save additional configuration (including [m3] section)
            from utils.config_loader import get_config
            config_dict = get_config().get_section('m3')
            config = {
                'use_dense': self.use_dense,
                'use_sparse': self.use_sparse,
                'use_colbert': self.use_colbert,
                'pooling_method': self.pooling_method,
                'normalize_embeddings': self.normalize_embeddings,
                'colbert_dim': self.colbert_dim,
                'sparse_top_k': self.sparse_top_k,
                'model_name_or_path': self.model_name_or_path,
                'm3_config': config_dict
            }
            with open(f"{save_path}/m3_config.json", 'w') as f:
                json.dump(config, f, indent=2)
            # Save custom heads if needed
            if self.use_sparse:
                torch.save(self.sparse_linear.state_dict(), os.path.join(save_path, 'sparse_linear.pt'))
            if self.use_colbert:
                torch.save(self.colbert_linear.state_dict(), os.path.join(save_path, 'colbert_linear.pt'))
            # Save tokenizer if available
            if self._tokenizer is not None:
                self._tokenizer.save_pretrained(save_path)
            logger.info(f"Model and config saved to {save_path}")
        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from directory, restoring custom heads and config."""
        import os
        import json
        try:
            # Load config written by save_pretrained (missing -> all defaults)
            config_path = os.path.join(model_path, "m3_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}
            # Instantiate model
            model = cls(
                model_name_or_path=config.get('model_name_or_path', model_path),
                use_dense=config.get('use_dense', True),
                use_sparse=config.get('use_sparse', True),
                use_colbert=config.get('use_colbert', True),
                pooling_method=config.get('pooling_method', 'cls'),
                normalize_embeddings=config.get('normalize_embeddings', True),
                colbert_dim=config.get('colbert_dim', -1),
                sparse_top_k=config.get('sparse_top_k', 100),
                **kwargs
            )
            # Load base model weights
            # NOTE(review): this reloads on CPU and may leave model.model on a
            # different device than the heads — confirm callers re-call .to().
            model.model = model.model.from_pretrained(model_path)
            # Load custom heads if present
            if model.use_sparse:
                sparse_path = os.path.join(model_path, 'sparse_linear.pt')
                if os.path.exists(sparse_path):
                    model.sparse_linear.load_state_dict(torch.load(sparse_path, map_location='cpu'))
            if model.use_colbert:
                colbert_path = os.path.join(model_path, 'colbert_linear.pt')
                if os.path.exists(colbert_path):
                    model.colbert_linear.load_state_dict(torch.load(colbert_path, map_location='cpu'))
            logger.info(f"Loaded BGE-M3 model from {model_path}")
            return model
        except Exception as e:
            logger.error(f"Failed to load model from {model_path}: {e}")
            raise

    def compute_similarity(
        self,
        queries: Union[torch.Tensor, Dict[str, torch.Tensor]],
        passages: Union[torch.Tensor, Dict[str, torch.Tensor]],
        mode: str = "dense"
    ) -> torch.Tensor:
        """
        Compute similarity scores between queries and passages with error handling.

        Args:
            queries: Query embeddings
            passages: Passage embeddings
            mode: Similarity mode ('dense', 'sparse', 'colbert')
        Returns:
            Similarity scores [num_queries, num_passages]
        """
        try:
            # Extract embeddings based on mode
            if isinstance(queries, dict):
                if mode not in queries:
                    raise ValueError(f"Mode '{mode}' not found in query embeddings")
                query_emb = queries[mode]
            else:
                query_emb = queries
            if isinstance(passages, dict):
                if mode not in passages:
                    raise ValueError(f"Mode '{mode}' not found in passage embeddings")
                passage_emb = passages[mode]
            else:
                passage_emb = passages
            # Validate tensor shapes
            if query_emb.dim() < 2 or passage_emb.dim() < 2:
                raise ValueError("Embeddings must be at least 2-dimensional")
            if mode == "dense":
                # Simple dot product for dense
                return torch.matmul(query_emb, passage_emb.t())
            elif mode == "sparse":
                # Dot product for sparse (simplified)
                return torch.matmul(query_emb, passage_emb.t())
            elif mode == "colbert":
                # MaxSim late interaction for ColBERT
                if query_emb.dim() != 3 or passage_emb.dim() != 3:
                    raise ValueError("ColBERT embeddings must be 3-dimensional [batch, seq_len, dim]")
                scores = []
                for q in query_emb:
                    # q: [query_len, dim]; passage_emb: [num_passages, passage_len, dim]
                    passage_scores = []
                    for p in passage_emb:
                        # Token-level similarity matrix
                        sim_matrix = torch.matmul(q, p.t())  # [query_len, passage_len]
                        # Max over passage tokens for each query token
                        max_sim_scores = sim_matrix.max(dim=1)[0]  # [query_len]
                        # Sum over query tokens
                        score = max_sim_scores.sum()
                        passage_scores.append(score)
                    scores.append(torch.stack(passage_scores))
                return torch.stack(scores)
            else:
                raise ValueError(f"Unknown similarity mode: {mode}")
        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            raise

    def get_sentence_embedding_dimension(self) -> int:
        """Get the dimension of sentence embeddings"""
        return self.config.hidden_size  # type: ignore[attr-defined]

    def get_word_embedding_dimension(self) -> int:
        """Get the dimension of word embeddings"""
        return self.config.hidden_size  # type: ignore[attr-defined]

483
models/bge_reranker.py Normal file
View File

@@ -0,0 +1,483 @@
"""
BGE-Reranker cross-encoder model wrapper
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
import logging
from dataclasses import dataclass
from utils.config_loader import get_config
logger = logging.getLogger(__name__)
@dataclass
class RerankerOutput:
    """Output class for reranker predictions"""
    # Final relevance scores derived from logits (sigmoid for a single label,
    # softmax distribution for multi-label).
    scores: torch.Tensor
    # Raw classifier outputs before the score transform.
    logits: torch.Tensor
    # Pooled (post-dropout) representation when the custom classification head
    # is used; None when an HF sequence-classification model produced the logits.
    embeddings: Optional[torch.Tensor] = None
def load_reranker_model_and_tokenizer(model_name_or_path, cache_dir=None, num_labels=1):
    """Load reranker model and tokenizer from HuggingFace or ModelScope based on config.toml.

    Args:
        model_name_or_path: Model identifier or local path
        cache_dir: Cache directory for downloaded models
        num_labels: Number of output labels for the classification head
    Returns:
        Tuple of (model, tokenizer, config, use_base_model); use_base_model is
        True when only a bare encoder could be loaded and a custom head is needed.
    Raises:
        RuntimeError: If the configured model source is unknown
    """
    config = get_config()
    # Try to get model source from config, with fallback to huggingface
    try:
        model_source = config.get('model_paths', 'source', 'huggingface')
    except Exception:
        # Fallback if config loading fails
        logger.warning("Failed to load model source from config, defaulting to 'huggingface'")
        model_source = 'huggingface'
    logger.info(f"Loading reranker model from source: {model_source}")
    if model_source == 'huggingface':
        from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=num_labels,
            cache_dir=cache_dir
        )
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name_or_path,
                config=model_config,
                cache_dir=cache_dir
            )
            use_base_model = False
        except Exception as e:
            # Bug fix: this was a bare `except:`, which also swallowed
            # SystemExit/KeyboardInterrupt. Fall back to a bare encoder plus a
            # custom classification head.
            logger.warning(f"Could not load sequence-classification model ({e}); loading base model instead")
            model = AutoModel.from_pretrained(
                model_name_or_path,
                config=model_config,
                cache_dir=cache_dir
            )
            use_base_model = True
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir
        )
        return model, tokenizer, model_config, use_base_model
    elif model_source == 'modelscope':
        try:
            from modelscope.models import Model as MSModel
            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
            # Ensure the snapshot is cached locally before loading.
            ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
            model = MSModel.from_pretrained(model_name_or_path, cache_dir=cache_dir)
            tokenizer = None  # ModelScope tokenizer loading can be added if needed
            model_config = None
            use_base_model = False
            return model, tokenizer, model_config, use_base_model
        except ImportError:
            logger.error("modelscope not available. Install it first: pip install modelscope")
            raise
        except Exception as e:
            logger.error(f"Failed to load reranker model from ModelScope: {e}")
            raise
    else:
        logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
        raise RuntimeError(f"Unknown model source: {model_source}")
class BGERerankerModel(nn.Module):
"""
BGE-Reranker cross-encoder model for passage reranking
Notes:
- All parameters (model_name_or_path, num_labels, cache_dir, use_fp16, gradient_checkpointing, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-reranker-base",
        num_labels: int = 1,
        cache_dir: Optional[str] = None,
        use_fp16: bool = False,
        gradient_checkpointing: bool = False,
        device: Optional[str] = None
    ):
        """
        Initialize BGE-Reranker model.

        Args:
            model_name_or_path: Path or identifier for pretrained model
            num_labels: Number of output labels
            cache_dir: Directory for model cache
            use_fp16: Use mixed precision
            gradient_checkpointing: Enable gradient checkpointing
            device: Device to use
        Raises:
            RuntimeError: If the backbone model cannot be loaded
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels
        self.use_fp16 = use_fp16
        # Load configuration and model; use_base_model is True when the loader
        # fell back to a bare encoder and we must attach our own head.
        try:
            logger.info(f"Loading BGE-Reranker model from {model_name_or_path} using source from config.toml...")
            self.model, self.tokenizer, self.config, self.use_base_model = load_reranker_model_and_tokenizer(
                model_name_or_path,
                cache_dir=cache_dir,
                num_labels=num_labels
            )
            logger.info("BGE-Reranker model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BGE-Reranker model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")
        # Add classification head (only needed for the bare-encoder path).
        # NOTE(review): assumes self.config has hidden_size when use_base_model
        # is True — ModelScope loading can return config=None; confirm upstream.
        if self.use_base_model:
            self.classifier = nn.Linear(self.config.hidden_size, num_labels)  # type: ignore[attr-defined]
            self.dropout = nn.Dropout(
                self.config.hidden_dropout_prob if hasattr(self.config, 'hidden_dropout_prob') else 0.1)  # type: ignore[attr-defined]
        # Enable gradient checkpointing (best-effort; not all backbones support it)
        if gradient_checkpointing:
            try:
                self.model.gradient_checkpointing_enable()
            except Exception:
                logger.warning("Gradient checkpointing requested but not supported by this model.")
        # Set device — done after head creation so .to() moves the whole module
        if device:
            try:
                self.device = torch.device(device)
                self.to(self.device)
            except Exception as e:
                logger.warning(f"Failed to move model to {device}: {e}, using CPU")
                self.device = torch.device('cpu')
                self.to(self.device)
        else:
            # Infer the device from the loaded parameters when none was given
            self.device = next(self.parameters()).device
        # Mixed precision: convert parameters in place after device placement
        if self.use_fp16:
            try:
                self.half()
            except Exception as e:
                logger.warning(f"Failed to set model to half precision: {e}")
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
return_dict: bool = True
) -> Union[torch.Tensor, RerankerOutput]:
"""
Forward pass for cross-encoder reranking
Args:
input_ids: Token IDs [batch_size, seq_length]
attention_mask: Attention mask [batch_size, seq_length]
token_type_ids: Token type IDs (optional)
labels: Labels for training (optional)
return_dict: Return RerankerOutput object
Returns:
Scores or RerankerOutput
"""
if self.use_base_model:
# Use base model + custom head
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
# Get pooled output (CLS token)
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
pooled_output = outputs.pooler_output
else:
# Manual pooling if pooler not available
last_hidden_state = outputs.last_hidden_state
pooled_output = last_hidden_state[:, 0]
# Apply dropout and classifier
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
embeddings = pooled_output
else:
# Use sequence classification model
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
labels=labels,
return_dict=True
)
logits = outputs.logits
embeddings = None
# Convert logits to scores (sigmoid for binary classification)
if self.num_labels == 1:
scores = torch.sigmoid(logits.squeeze(-1))
else:
scores = F.softmax(logits, dim=-1)
if not return_dict:
return scores
return RerankerOutput(
scores=scores,
logits=logits,
embeddings=embeddings
)
def predict(
self,
queries: Union[str, List[str]],
passages: Union[str, List[str], List[List[str]]],
batch_size: int = 32,
max_length: int = 512,
show_progress_bar: bool = False,
convert_to_numpy: bool = True,
normalize_scores: bool = False
) -> Union[np.ndarray, torch.Tensor, List[float]]:
"""
Predict reranking scores for query-passage pairs
Args:
queries: Query or list of queries
passages: Passage(s) or list of passages per query
batch_size: Batch size for inference
max_length: Maximum sequence length
show_progress_bar: Show progress bar
convert_to_numpy: Convert to numpy array
normalize_scores: Normalize scores to [0, 1]
Returns:
Reranking scores
"""
# Handle input formats
if isinstance(queries, str):
queries = [queries]
if isinstance(passages, str):
passages = [[passages]]
elif isinstance(passages[0], str):
passages = [passages] # type: ignore[assignment]
# Ensure queries and passages match
if len(queries) == 1 and len(passages) > 1:
queries = queries * len(passages)
# Create pairs
pairs = []
for query, passage_list in zip(queries, passages):
if isinstance(passage_list, str):
passage_list = [passage_list]
for passage in passage_list:
pairs.append((query, passage))
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
# Score in batches
all_scores = []
if show_progress_bar:
from tqdm import tqdm
pbar = tqdm(total=len(pairs), desc="Scoring")
for i in range(0, len(pairs), batch_size):
batch_pairs = pairs[i:i + batch_size]
# Tokenize
batch_queries = [pair[0] for pair in batch_pairs]
batch_passages = [pair[1] for pair in batch_pairs]
encoded = tokenizer(
batch_queries,
batch_passages,
padding=True,
truncation=True,
max_length=max_length,
return_tensors='pt'
)
# Move to device
encoded = {k: v.to(self.device) for k, v in encoded.items()}
# Get scores
with torch.no_grad():
outputs = self.forward(**encoded, return_dict=True)
scores = outputs.scores # type: ignore[attr-defined]
all_scores.append(scores)
if show_progress_bar:
pbar.update(len(batch_pairs))
if show_progress_bar:
pbar.close()
# Concatenate scores
all_scores = torch.cat(all_scores, dim=0)
# Normalize if requested
if normalize_scores:
all_scores = (all_scores - all_scores.min()) / (all_scores.max() - all_scores.min() + 1e-8)
# Convert format
if convert_to_numpy:
all_scores = all_scores.cpu().numpy()
elif not isinstance(all_scores, torch.Tensor):
all_scores = all_scores.cpu()
# Reshape if single query
if len(queries) == 1 and len(passages) == 1:
return all_scores.item() if convert_to_numpy else all_scores # type: ignore[attr-defined]
return all_scores
def rerank(
self,
query: str,
passages: List[str],
top_k: Optional[int] = None,
return_scores: bool = False,
batch_size: int = 32,
max_length: int = 512
) -> Union[List[str], Tuple[List[str], List[float]]]:
"""
Rerank passages for a query
Args:
query: Query string
passages: List of passages to rerank
top_k: Return only top-k passages
return_scores: Also return scores
batch_size: Batch size for scoring
max_length: Maximum sequence length
Returns:
Reranked passages (and optionally scores)
"""
# Get scores
scores = self.predict(
query,
passages,
batch_size=batch_size,
max_length=max_length,
convert_to_numpy=True
)
# Sort by scores (descending)
sorted_indices = np.argsort(scores)[::-1]
# Apply top_k if specified
if top_k is not None:
sorted_indices = sorted_indices[:top_k]
# Get reranked passages
reranked_passages = [passages[i] for i in sorted_indices]
if return_scores:
reranked_scores = [float(scores[i]) for i in sorted_indices]
return reranked_passages, reranked_scores
return reranked_passages
def save_pretrained(self, save_path: str):
"""Save model to directory, including custom heads and config."""
try:
import os
import json
os.makedirs(save_path, exist_ok=True)
# Save the base model
self.model.save_pretrained(save_path)
# Save additional configuration (including [reranker] section)
from utils.config_loader import get_config
config_dict = get_config().get_section('reranker')
config = {
'use_fp16': self.use_fp16 if hasattr(self, 'use_fp16') else False,
'num_labels': getattr(self, 'num_labels', 1),
'reranker_config': config_dict
}
with open(f"{save_path}/reranker_config.json", 'w') as f:
json.dump(config, f, indent=2)
# Save custom heads if present
if hasattr(self, 'classifier'):
torch.save(self.classifier.state_dict(), os.path.join(save_path, 'classifier.pt'))
if hasattr(self, 'dropout'):
torch.save(self.dropout.state_dict(), os.path.join(save_path, 'dropout.pt'))
# Save tokenizer if available
if hasattr(self, '_tokenizer') and self._tokenizer is not None:
self._tokenizer.save_pretrained(save_path) # type: ignore[attr-defined]
logger.info(f"Reranker model and config saved to {save_path}")
except Exception as e:
logger.error(f"Failed to save reranker model: {e}")
raise
    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """
        Load a reranker from a directory written by ``save_pretrained``.

        Restores the base transformer weights, the wrapper settings stored in
        ``reranker_config.json`` (falling back to defaults when that file is
        absent, e.g. a bare checkpoint), and any custom head weights
        (``classifier.pt`` / ``dropout.pt``) found alongside them.

        Args:
            model_path: Directory containing the saved model files
            **kwargs: Extra keyword arguments forwarded to the constructor

        Returns:
            An initialized model instance

        Raises:
            Exception: Re-raises any failure after logging it
        """
        import os
        import json
        try:
            # Load config
            config_path = os.path.join(model_path, "reranker_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}
            # Instantiate model
            model = cls(
                model_name_or_path=model_path,
                num_labels=config.get('num_labels', 1),
                use_fp16=config.get('use_fp16', False),
                **kwargs
            )
            # Load base model weights
            # NOTE(review): from_pretrained is a classmethod, so this call
            # loads a fresh base model from model_path and rebinds it;
            # presumably redundant with the cls(...) call above — confirm.
            model.model = model.model.from_pretrained(model_path)
            # Load custom heads if present
            # NOTE(review): torch.load without weights_only=True unpickles
            # arbitrary objects — confirm these checkpoint files are trusted.
            classifier_path = os.path.join(model_path, 'classifier.pt')
            if hasattr(model, 'classifier') and os.path.exists(classifier_path):
                model.classifier.load_state_dict(torch.load(classifier_path, map_location='cpu'))
            dropout_path = os.path.join(model_path, 'dropout.pt')
            if hasattr(model, 'dropout') and os.path.exists(dropout_path):
                model.dropout.load_state_dict(torch.load(dropout_path, map_location='cpu'))
            logger.info(f"Loaded BGE-Reranker model from {model_path}")
            return model
        except Exception as e:
            logger.error(f"Failed to load reranker model from {model_path}: {e}")
            raise
def compute_loss(
self,
scores: torch.Tensor,
labels: torch.Tensor,
loss_type: str = "binary_cross_entropy"
) -> torch.Tensor:
"""
Compute loss for training
Args:
scores: Model scores
labels: Ground truth labels
loss_type: Type of loss function
Returns:
Loss value
"""
if loss_type == "binary_cross_entropy":
return F.binary_cross_entropy(scores, labels.float())
elif loss_type == "cross_entropy":
return F.cross_entropy(scores, labels.long())
elif loss_type == "mse":
return F.mse_loss(scores, labels.float())
else:
raise ValueError(f"Unknown loss type: {loss_type}")

# ============================================================================
# models/losses.py (new file in this commit — 507 lines)
# ============================================================================
"""
Loss functions for BGE fine-tuning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import numpy as np
class InfoNCELoss(nn.Module):
    """
    InfoNCE contrastive loss (BGE-M3 paper, default temperature 0.02).

    Notes:
        - All parameters (temperature, reduction, etc.) should be passed from
          config-driven training scripts for consistency; this class does not
          read config itself.
    """

    def __init__(self, temperature: float = 0.02, reduction: str = 'mean'):
        super().__init__()
        self.temperature = temperature
        self.reduction = reduction
        self.cross_entropy = nn.CrossEntropyLoss(reduction=reduction)

    def forward(
        self,
        query_embeddings: torch.Tensor,
        positive_embeddings: torch.Tensor,
        negative_embeddings: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss over L2-normalized embeddings.

        Args:
            query_embeddings: [batch_size, embedding_dim]
            positive_embeddings: [batch_size, embedding_dim]
            negative_embeddings: [batch_size, num_negatives, embedding_dim],
                or None to fall back to in-batch negatives
        """
        batch = query_embeddings.size(0)
        device = query_embeddings.device
        q = F.normalize(query_embeddings, p=2, dim=1)
        p = F.normalize(positive_embeddings, p=2, dim=1)
        has_explicit_negatives = (
            negative_embeddings is not None and negative_embeddings.numel() > 0
        )
        if has_explicit_negatives:
            n = F.normalize(negative_embeddings, p=2, dim=-1)
            # Per-example positive similarity: [batch, 1]
            pos_logits = (q * p).sum(dim=1, keepdim=True) / self.temperature
            # Per-example negative similarities: [batch, num_negatives]
            neg_logits = torch.bmm(n, q.unsqueeze(2)).squeeze(2) / self.temperature
            # Column 0 holds the positive, so the target class is always 0
            logits = torch.cat([pos_logits, neg_logits], dim=1)
            targets = torch.zeros(batch, dtype=torch.long, device=device)
        else:
            # In-batch negatives: every other positive acts as a negative;
            # true pairs sit on the diagonal
            logits = (q @ p.t()) / self.temperature
            targets = torch.arange(batch, dtype=torch.long, device=device)
        return self.cross_entropy(logits, targets)
class M3KnowledgeDistillationLoss(nn.Module):
    """
    Knowledge distillation loss for BGE-M3's multi-functionality training.

    Distills dense, sparse, and multi-vector (colbert) representations from
    a teacher model into the student via temperature-scaled KL divergence,
    optionally mixed with a supervised cross-entropy term.

    Notes:
        - All parameters (temperature, alpha, distill_types, etc.) should be
          passed from config-driven training scripts for consistency; this
          class does not read config itself.
    """

    def __init__(
        self,
        temperature: float = 4.0,
        alpha: float = 0.5,
        distill_types: Optional[List[str]] = None
    ):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        # Avoid the mutable-default-argument pitfall; None means the standard
        # BGE-M3 representation types
        self.distill_types = distill_types if distill_types is not None else ['dense', 'sparse', 'colbert']
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(
        self,
        student_outputs: Dict[str, torch.Tensor],
        teacher_outputs: Dict[str, torch.Tensor],
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            student_outputs: Dictionary with 'dense', 'sparse', 'colbert' outputs
            teacher_outputs: Dictionary with teacher model outputs
            labels: Optional labels for a supervised cross-entropy term

        Returns:
            total_loss: Combined distillation loss
            loss_dict: Dictionary of individual loss values
        """
        device = next(iter(student_outputs.values())).device if student_outputs else 'cpu'
        total_loss = torch.tensor(0.0, device=device)
        loss_dict = {}
        # Temperature-scaled KL distillation per representation type
        for repr_type in self.distill_types:
            if repr_type in student_outputs and repr_type in teacher_outputs:
                student_log_probs = F.log_softmax(student_outputs[repr_type] / self.temperature, dim=-1)
                teacher_probs = F.softmax(teacher_outputs[repr_type] / self.temperature, dim=-1)
                # T^2 scaling keeps gradient magnitudes comparable across
                # temperatures (standard distillation practice)
                kl_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)
                total_loss = total_loss + self.alpha * kl_loss
                loss_dict[f'{repr_type}_distill_loss'] = kl_loss.detach().cpu().item()
        # Supervised term, weighted by the complement of alpha
        if labels is not None and 'dense' in student_outputs:
            ce_loss = F.cross_entropy(student_outputs['dense'], labels)
            total_loss = total_loss + (1 - self.alpha) * ce_loss
            loss_dict['supervised_loss'] = ce_loss.detach().cpu().item()
        loss_dict['total_loss'] = total_loss.detach().cpu().item()
        return total_loss, loss_dict
class ListwiseCrossEntropyLoss(nn.Module):
    """
    Listwise cross-entropy loss for reranker training.

    Notes:
        - All parameters (temperature, etc.) should be passed from
          config-driven training scripts for consistency; this class does not
          read config itself.
    """

    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
        """
        Compute listwise cross-entropy, averaged over non-empty groups.

        Args:
            scores: Relevance scores [total_passages]
            labels: Binary relevance labels [total_passages]
            groups: Number of passages belonging to each query

        Returns:
            Scalar loss (0 if every group is empty)
        """
        total = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
        start = 0
        for size in groups:
            end = start + size
            group_scores = scores[start:end]
            group_labels = labels[start:end]
            start = end
            # Nothing to learn from an empty group or one with no positives
            if size == 0 or group_labels.sum() == 0:
                continue
            # log_softmax over the group for numerical stability
            log_probs = F.log_softmax(group_scores / self.temperature, dim=0)
            # Spread the target mass uniformly over the positive passages
            target_dist = group_labels.float() / group_labels.sum()
            total = total + (-(target_dist * log_probs).sum())
        # NOTE: groups skipped for lacking positives still count in the
        # denominator — this matches the original averaging behavior.
        nonempty = sum(1 for size in groups if size > 0)
        return total / nonempty if nonempty > 0 else total
class PairwiseMarginLoss(nn.Module):
    """
    Pairwise margin ranking loss for reranker training.

    For every (positive, negative) pair inside a group, penalizes
    max(0, margin - (pos_score - neg_score)); the result is averaged over
    the total number of pairs across all groups.

    Notes:
        - All parameters (margin, etc.) should be passed from config-driven
          training scripts for consistency; this class does not read config
          itself.
    """

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin
        # Kept for backward compatibility with any external references,
        # though forward() now computes the equivalent hinge directly
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
        """
        Args:
            scores: Relevance scores [total_passages]
            labels: Binary relevance labels [total_passages]
            groups: List of group sizes (passages per query)

        Returns:
            Mean margin loss over all positive-negative pairs (0 if no pairs)
        """
        total_loss = torch.tensor(0.0, device=scores.device)
        num_pairs = 0
        offset = 0
        for group_size in groups:
            group_scores = scores[offset:offset + group_size]
            group_labels = labels[offset:offset + group_size]
            offset += group_size
            pos_scores = group_scores[group_labels == 1]
            neg_scores = group_scores[group_labels == 0]
            if pos_scores.numel() == 0 or neg_scores.numel() == 0:
                continue
            # Vectorized over all pos x neg pairs, replacing the original
            # nested Python loop: relu(margin - (pos - neg)) is exactly
            # MarginRankingLoss(pos, neg, target=1) per pair.
            diffs = pos_scores.unsqueeze(1) - neg_scores.unsqueeze(0)
            pair_losses = F.relu(self.margin - diffs)
            total_loss = total_loss + pair_losses.sum()
            num_pairs += pair_losses.numel()
        return total_loss / max(num_pairs, 1)
class DynamicListwiseDistillationLoss(nn.Module):
    """
    Dynamic listwise distillation loss for RocketQAv2-style joint training.

    Distills the teacher's per-group score distribution into the student via
    KL divergence, optionally weighting each group by a confidence score.

    Notes:
        - All parameters (temperature, dynamic_weight, etc.) should be passed
          from config-driven training scripts for consistency; this class
          does not read config itself.
    """

    def __init__(self, temperature: float = 3.0, dynamic_weight: bool = True):
        super().__init__()
        self.temperature = temperature
        self.dynamic_weight = dynamic_weight
        self.kl_div = nn.KLDivLoss(reduction='none')

    def forward(
        self,
        student_scores: torch.Tensor,
        teacher_scores: torch.Tensor,
        groups: List[int],
        confidence_scores: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            student_scores: Student model scores [total_passages]
            teacher_scores: Teacher model scores [total_passages]
            groups: List of group sizes (passages per query)
            confidence_scores: Optional per-group confidence for dynamic weighting

        Returns:
            Mean per-group KL distillation loss (0 if groups is empty)
        """
        total_loss = torch.tensor(0.0, device=student_scores.device)
        # Guard the empty batch: the original divided by len(groups)
        # unconditionally, raising ZeroDivisionError
        if not groups:
            return total_loss
        offset = 0
        for i, group_size in enumerate(groups):
            # Skip degenerate groups (log_softmax over 0 elements)
            if group_size == 0:
                continue
            student_group = student_scores[offset:offset + group_size] / self.temperature
            teacher_group = teacher_scores[offset:offset + group_size] / self.temperature
            offset += group_size
            # Student in log-space, teacher in probability space, as KLDivLoss expects
            student_log_probs = F.log_softmax(student_group, dim=0)
            teacher_probs = F.softmax(teacher_group, dim=0)
            kl_loss = self.kl_div(student_log_probs, teacher_probs).sum()
            # Down-weight groups with low confidence
            if self.dynamic_weight and confidence_scores is not None:
                kl_loss = kl_loss * confidence_scores[i]
            # T^2 scaling keeps gradients comparable across temperatures
            total_loss = total_loss + kl_loss * (self.temperature ** 2)
        return total_loss / len(groups)
class MultiTaskLoss(nn.Module):
    """
    Multi-task loss combining several objectives with fixed or learned weights.

    With adaptive_weights=True, applies homoscedastic uncertainty weighting:
    total += exp(-log_var) * loss + log_var, with one learnable log-variance
    per task.

    Notes:
        - All parameters (loss_weights, adaptive_weights, etc.) should be
          passed from config-driven training scripts for consistency; this
          class does not read config itself.
    """

    def __init__(
        self,
        loss_weights: Optional[Dict[str, float]] = None,
        adaptive_weights: bool = False
    ):
        super().__init__()
        self.loss_weights = loss_weights or {
            'dense': 1.0,
            'sparse': 0.3,
            'colbert': 0.5,
            'distillation': 1.0
        }
        self.adaptive_weights = adaptive_weights
        # Fixed name -> slot mapping so each task always uses the same
        # learnable log-variance regardless of the order (or subset) of the
        # losses dict passed to forward(). The original indexed log_vars by
        # enumeration order of the incoming dict, which could pair a task
        # with the wrong parameter or overflow the parameter vector.
        self._task_index = {name: i for i, name in enumerate(self.loss_weights)}
        if adaptive_weights:
            # Learnable log-variances (uncertainty weighting)
            self.log_vars = nn.Parameter(torch.zeros(len(self.loss_weights)))

    def forward(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            losses: Mapping from task name to its scalar loss; names not in
                loss_weights are ignored

        Returns:
            total_loss: Combined loss
            loss_dict: Individual loss values and the weights applied
        """
        device = next(iter(losses.values())).device if losses else 'cpu'
        total_loss = torch.tensor(0.0, device=device)
        loss_dict = {}

        def _as_float(value):
            # Uniform tensor/number conversion for logging
            return value.detach().cpu().item() if isinstance(value, torch.Tensor) else float(value)

        for loss_name, loss_value in losses.items():
            if loss_name not in self.loss_weights:
                continue
            if self.adaptive_weights:
                slot = self._task_index[loss_name]
                precision = torch.exp(-self.log_vars[slot])
                # precision * loss + log_var: the log_var term regularizes
                # against collapsing the weight to zero
                total_loss = total_loss + precision * loss_value + self.log_vars[slot]
                weight = precision
            else:
                weight = self.loss_weights[loss_name]
                total_loss = total_loss + weight * loss_value
            loss_dict[f'{loss_name}_loss'] = _as_float(loss_value)
            loss_dict[f'{loss_name}_weight'] = _as_float(weight)
        loss_dict['total_loss'] = _as_float(total_loss)
        return total_loss, loss_dict
class SparseLoss(nn.Module):
    """
    SPLADE+ sparse loss with log1p(ReLU(.)) activation and FLOPS regularization.

    For BGE-M3's token-position based sparse representations: per-sequence
    token weights are sum-pooled to one score per example, and an in-batch
    contrastive loss is computed over the cosine similarity of those scores.

    Args:
        flops_loss_weight: Weight for FLOPS regularization
        sparse_reg_weight: Weight for L1 sparsity regularization
        temperature: Temperature for similarity computation
    """

    def __init__(
        self,
        flops_loss_weight: float = 1e-3,
        sparse_reg_weight: float = 1e-4,
        temperature: float = 0.02
    ):
        super().__init__()
        self.flops_loss_weight = flops_loss_weight
        self.sparse_reg_weight = sparse_reg_weight
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(
        self,
        query_sparse: torch.Tensor,
        doc_sparse: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            query_sparse: [batch, query_seq_len] query sparse token weights
            doc_sparse: [batch, doc_seq_len] document sparse token weights
            labels: [batch] labels (unused; in-batch diagonal pairs are the
                positives — parameter kept for interface compatibility)

        Returns:
            Tuple of (total_loss, loss_dict)
        """
        batch_size = query_sparse.size(0)
        device = query_sparse.device
        # SPLADE+ activation: log1p(ReLU(.)) keeps weights non-negative and
        # compresses large activations
        query_splade = torch.log1p(F.relu(query_sparse))
        doc_splade = torch.log1p(F.relu(doc_sparse))
        # Query and document sequence lengths differ, so aggregate each side
        # to a single sequence-level score via sum pooling
        query_scores = query_splade.sum(dim=1, keepdim=True)  # [batch, 1]
        doc_scores = doc_splade.sum(dim=1, keepdim=True)      # [batch, 1]
        # Cosine similarity of the pooled scores. (The original first built a
        # raw outer-product similarity and immediately overwrote it with this
        # cosine version — that dead computation is removed.)
        query_norm = F.normalize(query_scores, p=2, dim=1)
        doc_norm = F.normalize(doc_scores, p=2, dim=1)
        similarity_matrix = torch.matmul(query_norm, doc_norm.t())
        logits = similarity_matrix / self.temperature
        # In-batch contrastive targets: diagonal entries are the positive pairs
        contrastive_labels = torch.arange(batch_size, device=device, dtype=torch.long)
        main_loss = self.cross_entropy(logits, contrastive_labels)
        # FLOPS regularization: mean total activation per example
        flops = (query_splade.sum(dim=-1) + doc_splade.sum(dim=-1)).mean()
        # L1 sparsity regularization: mean activation per token position
        sparsity_loss = (query_splade.sum() + doc_splade.sum()) / (query_splade.numel() + doc_splade.numel())
        # All components are tensors by construction, so no re-wrapping needed
        total_loss = main_loss + self.flops_loss_weight * flops + self.sparse_reg_weight * sparsity_loss
        loss_dict = {
            'main_loss': main_loss.detach().cpu().item(),
            'flops': flops.detach().cpu().item(),
            'sparsity_loss': sparsity_loss.detach().cpu().item(),
            'total_loss': total_loss.detach().cpu().item()
        }
        return total_loss, loss_dict
class ColBERTLoss(nn.Module):
    """
    ColBERT late-interaction (maxsim) loss.

    Args:
        temperature: Scaling applied to token-level similarities
    """

    def __init__(self, temperature: float = 0.02):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query_embeds: torch.Tensor,
        doc_embeds: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            query_embeds: [batch, query_len, dim] query token embeddings
            doc_embeds: [batch, doc_len, dim] document token embeddings
            labels: [batch] binary relevance labels

        Returns:
            Scalar BCE-with-logits loss over the aggregated maxsim scores
        """
        # Token-level similarities: [batch, query_len, doc_len]
        token_sims = torch.einsum('bqd,bkd->bqk', query_embeds, doc_embeds) / self.temperature
        # Late interaction: each query token takes its best-matching doc token
        best_per_query_token = token_sims.max(dim=2).values  # [batch, query_len]
        # Mean-aggregate over query tokens to one score per pair
        pair_scores = best_per_query_token.mean(dim=1)  # [batch]
        return F.binary_cross_entropy_with_logits(pair_scores, labels.float())