init commit
This commit is contained in:
25
models/__init__.py
Normal file
25
models/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Model implementations for BGE fine-tuning."""

from .bge_m3 import BGEM3Model
from .bge_reranker import BGERerankerModel
from .losses import (
    InfoNCELoss,
    M3KnowledgeDistillationLoss,
    ListwiseCrossEntropyLoss,
    PairwiseMarginLoss,
    DynamicListwiseDistillationLoss,
    MultiTaskLoss,
)

# Explicit public API of the models package.
__all__ = [
    'BGEM3Model',
    'BGERerankerModel',
    'InfoNCELoss',
    'M3KnowledgeDistillationLoss',
    'ListwiseCrossEntropyLoss',
    'PairwiseMarginLoss',
    'DynamicListwiseDistillationLoss',
    'MultiTaskLoss',
]
|
||||
768
models/bge_m3.py
Normal file
768
models/bge_m3.py
Normal file
@@ -0,0 +1,768 @@
|
||||
"""
|
||||
BGE-M3 model wrapper for multi-functionality embedding
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
import numpy as np
|
||||
import logging
|
||||
from utils.config_loader import get_config
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(model_name_or_path, cache_dir=None):
    """
    Load model and tokenizer with comprehensive error handling and fallback strategies.

    Args:
        model_name_or_path: Model identifier or local path
        cache_dir: Cache directory for downloaded models

    Returns:
        Tuple of (model, tokenizer, config)

    Raises:
        RuntimeError: If model loading fails after all attempts
    """
    config = get_config()
    primary = config.get('model_paths', 'source', 'huggingface')

    # Map each known source name to its loader; anything else raises below.
    loaders = {
        'huggingface': _load_from_huggingface,
        'modelscope': _load_from_modelscope,
    }

    # Human-readable record of every failed attempt, for the final error.
    failures = []

    # 1) Primary source taken from config.toml.
    try:
        if primary not in loaders:
            raise ValueError(f"Unknown model source: {primary}")
        return loaders[primary](model_name_or_path, cache_dir)
    except Exception as e:
        failures.append(f"{primary}: {str(e)}")
        logger.warning(f"Primary loading from {primary} failed: {e}")

    # 2) The other hub as a fallback.
    fallback_sources = ['huggingface'] if primary == 'modelscope' else ['modelscope']
    for source in fallback_sources:
        try:
            logger.info(f"Attempting fallback loading from {source}")
            return loaders[source](model_name_or_path, cache_dir)
        except Exception as e:
            failures.append(f"{source}: {str(e)}")
            logger.warning(f"Fallback loading from {source} failed: {e}")

    # 3) Last resort: a local checkout, if the path exists on disk.
    if os.path.exists(model_name_or_path):
        try:
            logger.info(f"Attempting to load from local path: {model_name_or_path}")
            return _load_from_local(model_name_or_path, cache_dir)
        except Exception as e:
            failures.append(f"local: {str(e)}")
            logger.warning(f"Local loading failed: {e}")

    # Everything failed -- surface the full attempt log.
    error_msg = "Failed to load model from any source. Attempts:\n" + "\n".join(failures)
    logger.error(error_msg)
    raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def _load_from_huggingface(model_name_or_path, cache_dir=None):
    """Load model from HuggingFace with error handling."""
    try:
        from transformers import AutoModel, AutoTokenizer, AutoConfig

        # Config is fetched first so an invalid model id fails fast,
        # before any large weight download starts.
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            trust_remote_code=True,
        )

        model = AutoModel.from_pretrained(
            model_name_or_path,
            config=model_config,
            cache_dir=cache_dir,
            trust_remote_code=True,
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            trust_remote_code=True,
        )

        logger.info(f"Successfully loaded model from HuggingFace: {model_name_or_path}")
        return model, tokenizer, model_config

    except ImportError as e:
        # Keep ImportError distinct so callers can tell "library missing"
        # apart from "download/load failed".
        raise ImportError(f"Transformers library not available: {e}")
    except Exception as e:
        raise RuntimeError(f"HuggingFace loading error: {e}")
|
||||
|
||||
|
||||
def _load_from_modelscope(model_name_or_path, cache_dir=None):
    """Load model from ModelScope with error handling."""
    try:
        from modelscope.models import Model as MSModel
        from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download

        # Materialize a local snapshot of the model files, then load from it.
        snapshot_dir = ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
        model = MSModel.from_pretrained(snapshot_dir)

        # Tokenizer loading is best-effort: many ModelScope checkpoints ship
        # HF-compatible tokenizer files, but their absence is tolerated.
        tokenizer = None
        try:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(snapshot_dir)
        except Exception as e:
            logger.warning(f"Could not load tokenizer from ModelScope: {e}")

        # Config is likewise optional; returned as a plain dict when found.
        cfg = None
        cfg_file = os.path.join(snapshot_dir, 'config.json')
        if os.path.exists(cfg_file):
            import json
            with open(cfg_file, 'r') as f:
                cfg = json.load(f)

        logger.info(f"Successfully loaded model from ModelScope: {model_name_or_path}")
        return model, tokenizer, cfg

    except ImportError as e:
        raise ImportError(f"ModelScope library not available: {e}")
    except Exception as e:
        raise RuntimeError(f"ModelScope loading error: {e}")
|
||||
|
||||
|
||||
def _load_from_local(model_path, cache_dir=None):
    """Load model from local directory with error handling."""
    try:
        from transformers import AutoModel, AutoTokenizer, AutoConfig

        if not os.path.isdir(model_path):
            raise ValueError(f"Local path is not a directory: {model_path}")

        # A usable checkpoint directory must at least contain a config.json.
        required_files = ['config.json']
        missing_files = [
            name for name in required_files
            if not os.path.exists(os.path.join(model_path, name))
        ]
        if missing_files:
            raise FileNotFoundError(f"Missing required files: {missing_files}")

        # Load config, weights and tokenizer straight from the directory.
        model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_path, config=model_config, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        logger.info(f"Successfully loaded model from local path: {model_path}")
        return model, tokenizer, model_config

    except Exception as e:
        raise RuntimeError(f"Local loading error: {e}")
|
||||
|
||||
|
||||
class BGEM3Model(nn.Module):
    """
    BGE-M3 model wrapper supporting dense, sparse, and multi-vector retrieval

    Notes:
        - All parameters (model_name_or_path, use_dense, use_sparse, use_colbert, pooling_method, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-m3",
        use_dense: bool = True,
        use_sparse: bool = True,
        use_colbert: bool = True,
        pooling_method: str = "cls",
        normalize_embeddings: bool = True,
        sentence_pooling_method: str = "cls",
        colbert_dim: int = -1,
        sparse_top_k: int = 100,
        cache_dir: Optional[str] = None,
        device: Optional[str] = None,
        gradient_checkpointing: bool = False
    ):
        """
        Args:
            model_name_or_path: Hub id or local path of the backbone model.
            use_dense/use_sparse/use_colbert: Which retrieval heads to enable.
            pooling_method: Pooling for dense vectors ('cls' or 'mean').
            normalize_embeddings: L2-normalize head outputs in forward().
            sentence_pooling_method: Secondary pooling switch consulted by forward().
            colbert_dim: ColBERT projection width; <= 0 keeps the hidden size.
            sparse_top_k: Stored for sparse-retrieval consumers (not used here).
            cache_dir: Cache directory for model downloads.
            device: Explicit device string; auto-detected when None.
            gradient_checkpointing: Trade compute for memory in the backbone.

        Raises:
            RuntimeError: If the backbone model cannot be loaded.
        """
        super().__init__()

        self.model_name_or_path = model_name_or_path
        self.use_dense = use_dense
        self.use_sparse = use_sparse
        self.use_colbert = use_colbert
        self.pooling_method = pooling_method
        self.normalize_embeddings = normalize_embeddings
        # BUG FIX: forward() reads self.sentence_pooling_method, but the
        # original __init__ accepted the parameter without ever storing it,
        # so every dense forward pass raised AttributeError.
        self.sentence_pooling_method = sentence_pooling_method
        self.colbert_dim = colbert_dim
        self.sparse_top_k = sparse_top_k
        self.gradient_checkpointing = gradient_checkpointing

        # Initialize model with comprehensive error handling
        try:
            logger.info(f"Loading BGE-M3 model from {model_name_or_path} using source from config.toml...")
            self.model, self.tokenizer, self.config = load_model_and_tokenizer(
                model_name_or_path,
                cache_dir=cache_dir
            )
            logger.info("BGE-M3 model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BGE-M3 model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")

        # Enable gradient checkpointing if requested
        if self.gradient_checkpointing:
            self.enable_gradient_checkpointing()

        # Ensure we have a valid config (ModelScope loading may return None).
        if self.config is None:
            logger.warning("Model config is None, creating a basic config")

            # Minimal stand-in so downstream hidden_size lookups succeed.
            class BasicConfig:
                def __init__(self):
                    self.hidden_size = 768  # Default BERT hidden size
            self.config = BasicConfig()

        # Initialize heads for different functionalities
        hidden_size = getattr(self.config, 'hidden_size', 768)

        if self.use_dense:
            # Dense retrieval doesn't need additional layers (uses CLS token)
            self.dense_linear = nn.Identity()

        if self.use_sparse:
            # Sparse retrieval head (SPLADE-style): one weight per token.
            self.sparse_linear = nn.Linear(hidden_size, 1)
            self.sparse_activation = nn.ReLU()

        if self.use_colbert:
            # ColBERT projection layer; falls back to hidden size when
            # colbert_dim is not positive.
            colbert_dim = colbert_dim if colbert_dim > 0 else hidden_size
            self.colbert_linear = nn.Linear(hidden_size, colbert_dim)

        # Set device with enhanced detection (NPU/CUDA/MPS/CPU).
        if device:
            self.device = torch.device(device)
        else:
            self.device = self._detect_device()

        try:
            self.to(self.device)
            logger.info(f"BGE-M3 model moved to device: {self.device}")
        except Exception as e:
            logger.warning(f"Failed to move model to {self.device}: {e}, using CPU")
            self.device = torch.device('cpu')
            self.to(self.device)

        # Cache for the lazily-loaded tokenizer (see _get_tokenizer).
        self._tokenizer = None
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
"""Enable gradient checkpointing for memory efficiency"""
|
||||
if hasattr(self.model, 'gradient_checkpointing_enable'):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
logger.info("Gradient checkpointing enabled for BGE-M3")
|
||||
else:
|
||||
logger.warning("Model does not support gradient checkpointing")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
"""Disable gradient checkpointing"""
|
||||
if hasattr(self.model, 'gradient_checkpointing_disable'):
|
||||
self.model.gradient_checkpointing_disable()
|
||||
logger.info("Gradient checkpointing disabled for BGE-M3")
|
||||
|
||||
def _detect_device(self) -> torch.device:
|
||||
"""Enhanced device detection with proper error handling"""
|
||||
# Check Ascend NPU first
|
||||
try:
|
||||
import torch_npu
|
||||
if hasattr(torch, 'npu') and torch.npu.is_available(): # type: ignore[attr-defined]
|
||||
logger.info("Using Ascend NPU")
|
||||
return torch.device("npu")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"NPU detection failed: {e}")
|
||||
pass
|
||||
|
||||
# Check CUDA
|
||||
if torch.cuda.is_available():
|
||||
logger.info("Using CUDA GPU")
|
||||
return torch.device("cuda")
|
||||
|
||||
# Check MPS (Apple Silicon)
|
||||
try:
|
||||
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # type: ignore[attr-defined]
|
||||
logger.info("Using Apple MPS")
|
||||
return torch.device("mps")
|
||||
except Exception as e:
|
||||
logger.debug(f"MPS detection failed: {e}")
|
||||
pass
|
||||
|
||||
# Default to CPU
|
||||
logger.info("Using CPU")
|
||||
return torch.device("cpu")
|
||||
|
||||
def _get_tokenizer(self):
|
||||
"""Lazy load tokenizer with error handling"""
|
||||
if self._tokenizer is None:
|
||||
try:
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name_or_path,
|
||||
trust_remote_code=True
|
||||
)
|
||||
logger.info("Tokenizer loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load tokenizer: {e}")
|
||||
raise RuntimeError(f"Tokenizer loading failed: {e}")
|
||||
return self._tokenizer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
return_dense: bool = True,
|
||||
return_sparse: bool = False,
|
||||
return_colbert: bool = False,
|
||||
return_dict: bool = True
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Forward pass for BGE-M3 with dense, sparse (SPLADE+), and ColBERT heads.
|
||||
Args:
|
||||
input_ids: Token IDs [batch_size, seq_length]
|
||||
attention_mask: Attention mask [batch_size, seq_length]
|
||||
token_type_ids: Token type IDs (optional)
|
||||
return_dense: Return dense embeddings
|
||||
return_sparse: Return sparse embeddings
|
||||
return_colbert: Return ColBERT embeddings
|
||||
return_dict: Return dictionary of outputs
|
||||
Returns:
|
||||
Embeddings or dictionary of embeddings
|
||||
"""
|
||||
# Validate inputs
|
||||
if input_ids is None or attention_mask is None:
|
||||
raise ValueError("input_ids and attention_mask are required")
|
||||
if input_ids.shape != attention_mask.shape:
|
||||
raise ValueError("input_ids and attention_mask must have the same shape")
|
||||
|
||||
# Determine what to return based on model configuration and parameters
|
||||
return_dense = return_dense and self.use_dense
|
||||
return_sparse = return_sparse and self.use_sparse
|
||||
return_colbert = return_colbert and self.use_colbert
|
||||
|
||||
if not (return_dense or return_sparse or return_colbert):
|
||||
logger.warning("No output head selected in forward. At least one of dense, sparse, or colbert must be True.")
|
||||
return_dense = True # Default to dense if nothing selected
|
||||
|
||||
try:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
return_dict=True
|
||||
)
|
||||
last_hidden_state = outputs.last_hidden_state # [batch_size, seq_length, hidden_size]
|
||||
except Exception as e:
|
||||
logger.error(f"Model forward pass failed: {e}")
|
||||
raise RuntimeError(f"Forward pass failed: {e}")
|
||||
|
||||
results = {}
|
||||
|
||||
# Dense embeddings (CLS or mean pooling)
|
||||
if return_dense:
|
||||
if self.pooling_method == "cls" or self.sentence_pooling_method == "cls":
|
||||
dense_vecs = last_hidden_state[:, 0]
|
||||
elif self.pooling_method == "mean" or self.sentence_pooling_method == "mean":
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.shape)
|
||||
masked_embeddings = last_hidden_state * expanded_mask
|
||||
sum_embeddings = torch.sum(masked_embeddings, dim=1)
|
||||
sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
|
||||
dense_vecs = sum_embeddings / sum_mask
|
||||
else:
|
||||
raise ValueError(f"Unknown pooling method: {self.pooling_method}")
|
||||
|
||||
dense_vecs = self.dense_linear(dense_vecs)
|
||||
# Ensure embeddings are float type
|
||||
dense_vecs = dense_vecs.float()
|
||||
if self.normalize_embeddings:
|
||||
dense_vecs = F.normalize(dense_vecs, p=2, dim=-1)
|
||||
results['dense'] = dense_vecs
|
||||
|
||||
# SPLADE+ sparse head
|
||||
if return_sparse:
|
||||
# SPLADE+ uses log1p(ReLU(.))
|
||||
sparse_logits = self.sparse_linear(last_hidden_state).squeeze(-1) # [batch, seq_len]
|
||||
sparse_activations = torch.log1p(F.relu(sparse_logits))
|
||||
|
||||
# Mask out padding tokens
|
||||
sparse_activations = sparse_activations * attention_mask.float()
|
||||
|
||||
# Optionally normalize sparse activations
|
||||
if self.normalize_embeddings:
|
||||
sparse_activations = F.normalize(sparse_activations, p=2, dim=-1)
|
||||
|
||||
results['sparse'] = sparse_activations
|
||||
|
||||
# FLOPS regularization (for loss, not forward, but add as attribute)
|
||||
self.splade_flops = torch.sum(sparse_activations)
|
||||
|
||||
# ColBERT head with late interaction (maxsim)
|
||||
if return_colbert:
|
||||
colbert_vecs = self.colbert_linear(last_hidden_state) # [batch, seq_len, colbert_dim]
|
||||
|
||||
# Mask out padding tokens
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand(colbert_vecs.shape)
|
||||
colbert_vecs = colbert_vecs * expanded_mask
|
||||
|
||||
if self.normalize_embeddings:
|
||||
colbert_vecs = F.normalize(colbert_vecs, p=2, dim=-1)
|
||||
|
||||
results['colbert'] = colbert_vecs
|
||||
|
||||
if return_dict:
|
||||
return results
|
||||
|
||||
# If only one head requested, return that
|
||||
if return_dense and len(results) == 1:
|
||||
return results['dense']
|
||||
if return_sparse and len(results) == 1:
|
||||
return results['sparse']
|
||||
if return_colbert and len(results) == 1:
|
||||
return results['colbert']
|
||||
|
||||
return results
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sentences: Union[str, List[str]],
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512,
|
||||
convert_to_numpy: bool = True,
|
||||
convert_to_tensor: bool = False,
|
||||
show_progress_bar: bool = False,
|
||||
normalize_embeddings: bool = True,
|
||||
return_dense: bool = True,
|
||||
return_sparse: bool = False,
|
||||
return_colbert: bool = False,
|
||||
device: Optional[str] = None
|
||||
) -> Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]:
|
||||
"""
|
||||
Complete encode implementation with comprehensive error handling
|
||||
|
||||
Args:
|
||||
sentences: Input sentences
|
||||
batch_size: Batch size for encoding
|
||||
max_length: Maximum sequence length
|
||||
convert_to_numpy: Convert to numpy arrays
|
||||
convert_to_tensor: Keep as tensors
|
||||
show_progress_bar: Show progress bar
|
||||
normalize_embeddings: Normalize embeddings
|
||||
return_dense/sparse/colbert: Which embeddings to return
|
||||
device: Device to use
|
||||
|
||||
Returns:
|
||||
Embeddings
|
||||
"""
|
||||
# Input validation
|
||||
if isinstance(sentences, str):
|
||||
sentences = [sentences]
|
||||
|
||||
if not sentences:
|
||||
raise ValueError("Empty sentences list provided")
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size must be positive")
|
||||
|
||||
if max_length <= 0:
|
||||
raise ValueError("max_length must be positive")
|
||||
|
||||
# Get tokenizer
|
||||
try:
|
||||
tokenizer = self._get_tokenizer()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
||||
|
||||
# Determine target device
|
||||
target_device = device if device else self.device
|
||||
|
||||
# Initialize storage for embeddings
|
||||
all_embeddings = {
|
||||
'dense': [] if return_dense else None,
|
||||
'sparse': [] if return_sparse else None,
|
||||
'colbert': [] if return_colbert else None
|
||||
}
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.eval()
|
||||
|
||||
try:
|
||||
# Import tqdm if show_progress_bar is True
|
||||
if show_progress_bar:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
iterator = tqdm(
|
||||
range(0, len(sentences), batch_size),
|
||||
desc="Encoding",
|
||||
total=(len(sentences) + batch_size - 1) // batch_size
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("tqdm not available, disabling progress bar")
|
||||
iterator = range(0, len(sentences), batch_size)
|
||||
else:
|
||||
iterator = range(0, len(sentences), batch_size)
|
||||
|
||||
# Process in batches
|
||||
for i in iterator:
|
||||
batch_sentences = sentences[i:i + batch_size]
|
||||
|
||||
# Tokenize with proper error handling
|
||||
try:
|
||||
encoded = tokenizer(
|
||||
batch_sentences,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Tokenization failed for batch {i//batch_size + 1}: {e}")
|
||||
|
||||
# Move to device
|
||||
try:
|
||||
encoded = {k: v.to(target_device) for k, v in encoded.items()}
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to move batch to device {target_device}: {e}")
|
||||
|
||||
# Get embeddings
|
||||
with torch.no_grad():
|
||||
try:
|
||||
outputs = self.forward(
|
||||
**encoded,
|
||||
return_dense=return_dense,
|
||||
return_sparse=return_sparse,
|
||||
return_colbert=return_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Forward pass failed for batch {i//batch_size + 1}: {e}")
|
||||
|
||||
# Collect embeddings
|
||||
for key in ['dense', 'sparse', 'colbert']:
|
||||
if all_embeddings[key] is not None and key in outputs:
|
||||
all_embeddings[key].append(outputs[key].cpu()) # type: ignore[arg-type]
|
||||
|
||||
# Concatenate and convert results
|
||||
final_embeddings = {}
|
||||
for key in ['dense', 'sparse', 'colbert']:
|
||||
if all_embeddings[key] is not None and len(all_embeddings[key]) > 0: # type: ignore[arg-type]
|
||||
try:
|
||||
embeddings = torch.cat(all_embeddings[key], dim=0) # type: ignore[arg-type]
|
||||
|
||||
if normalize_embeddings and key == 'dense':
|
||||
embeddings = F.normalize(embeddings, p=2, dim=-1)
|
||||
|
||||
if convert_to_numpy:
|
||||
embeddings = embeddings.numpy()
|
||||
elif not convert_to_tensor:
|
||||
embeddings = embeddings
|
||||
|
||||
final_embeddings[key] = embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process {key} embeddings: {e}")
|
||||
raise
|
||||
|
||||
# Return format
|
||||
if len(final_embeddings) == 1:
|
||||
return list(final_embeddings.values())[0]
|
||||
|
||||
if not final_embeddings:
|
||||
raise RuntimeError("No embeddings were generated")
|
||||
|
||||
return final_embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Encoding failed: {e}")
|
||||
raise
|
||||
|
||||
def save_pretrained(self, save_path: str):
|
||||
"""Save model to directory with error handling, including custom heads and config."""
|
||||
try:
|
||||
import os
|
||||
import json
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save the base model
|
||||
self.model.save_pretrained(save_path)
|
||||
|
||||
# Save additional configuration (including [m3] section)
|
||||
from utils.config_loader import get_config
|
||||
config_dict = get_config().get_section('m3')
|
||||
config = {
|
||||
'use_dense': self.use_dense,
|
||||
'use_sparse': self.use_sparse,
|
||||
'use_colbert': self.use_colbert,
|
||||
'pooling_method': self.pooling_method,
|
||||
'normalize_embeddings': self.normalize_embeddings,
|
||||
'colbert_dim': self.colbert_dim,
|
||||
'sparse_top_k': self.sparse_top_k,
|
||||
'model_name_or_path': self.model_name_or_path,
|
||||
'm3_config': config_dict
|
||||
}
|
||||
with open(f"{save_path}/m3_config.json", 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save custom heads if needed
|
||||
if self.use_sparse:
|
||||
torch.save(self.sparse_linear.state_dict(), os.path.join(save_path, 'sparse_linear.pt'))
|
||||
if self.use_colbert:
|
||||
torch.save(self.colbert_linear.state_dict(), os.path.join(save_path, 'colbert_linear.pt'))
|
||||
|
||||
# Save tokenizer if available
|
||||
if self._tokenizer is not None:
|
||||
self._tokenizer.save_pretrained(save_path)
|
||||
|
||||
logger.info(f"Model and config saved to {save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save model: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, **kwargs):
|
||||
"""Load model from directory, restoring custom heads and config."""
|
||||
import os
|
||||
import json
|
||||
try:
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "m3_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
# Instantiate model
|
||||
model = cls(
|
||||
model_name_or_path=config.get('model_name_or_path', model_path),
|
||||
use_dense=config.get('use_dense', True),
|
||||
use_sparse=config.get('use_sparse', True),
|
||||
use_colbert=config.get('use_colbert', True),
|
||||
pooling_method=config.get('pooling_method', 'cls'),
|
||||
normalize_embeddings=config.get('normalize_embeddings', True),
|
||||
colbert_dim=config.get('colbert_dim', -1),
|
||||
sparse_top_k=config.get('sparse_top_k', 100),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Load base model weights
|
||||
model.model = model.model.from_pretrained(model_path)
|
||||
|
||||
# Load custom heads if present
|
||||
if model.use_sparse:
|
||||
sparse_path = os.path.join(model_path, 'sparse_linear.pt')
|
||||
if os.path.exists(sparse_path):
|
||||
model.sparse_linear.load_state_dict(torch.load(sparse_path, map_location='cpu'))
|
||||
if model.use_colbert:
|
||||
colbert_path = os.path.join(model_path, 'colbert_linear.pt')
|
||||
if os.path.exists(colbert_path):
|
||||
model.colbert_linear.load_state_dict(torch.load(colbert_path, map_location='cpu'))
|
||||
|
||||
logger.info(f"Loaded BGE-M3 model from {model_path}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model from {model_path}: {e}")
|
||||
raise
|
||||
|
||||
def compute_similarity(
|
||||
self,
|
||||
queries: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
passages: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
mode: str = "dense"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute similarity scores between queries and passages with error handling
|
||||
|
||||
Args:
|
||||
queries: Query embeddings
|
||||
passages: Passage embeddings
|
||||
mode: Similarity mode ('dense', 'sparse', 'colbert')
|
||||
|
||||
Returns:
|
||||
Similarity scores [num_queries, num_passages]
|
||||
"""
|
||||
try:
|
||||
# Extract embeddings based on mode
|
||||
if isinstance(queries, dict):
|
||||
if mode not in queries:
|
||||
raise ValueError(f"Mode '{mode}' not found in query embeddings")
|
||||
query_emb = queries[mode]
|
||||
else:
|
||||
query_emb = queries
|
||||
|
||||
if isinstance(passages, dict):
|
||||
if mode not in passages:
|
||||
raise ValueError(f"Mode '{mode}' not found in passage embeddings")
|
||||
passage_emb = passages[mode]
|
||||
else:
|
||||
passage_emb = passages
|
||||
|
||||
# Validate tensor shapes
|
||||
if query_emb.dim() < 2 or passage_emb.dim() < 2:
|
||||
raise ValueError("Embeddings must be at least 2-dimensional")
|
||||
|
||||
if mode == "dense":
|
||||
# Simple dot product for dense
|
||||
return torch.matmul(query_emb, passage_emb.t())
|
||||
|
||||
elif mode == "sparse":
|
||||
# Dot product for sparse (simplified)
|
||||
return torch.matmul(query_emb, passage_emb.t())
|
||||
|
||||
elif mode == "colbert":
|
||||
# MaxSim for ColBERT
|
||||
if query_emb.dim() != 3 or passage_emb.dim() != 3:
|
||||
raise ValueError("ColBERT embeddings must be 3-dimensional [batch, seq_len, dim]")
|
||||
|
||||
scores = []
|
||||
for q in query_emb:
|
||||
# q: [query_len, dim]
|
||||
# passage_emb: [num_passages, passage_len, dim]
|
||||
passage_scores = []
|
||||
for p in passage_emb:
|
||||
# Compute similarity matrix
|
||||
sim_matrix = torch.matmul(q, p.t()) # [query_len, passage_len]
|
||||
# Max over passage tokens for each query token
|
||||
max_sim_scores = sim_matrix.max(dim=1)[0] # [query_len]
|
||||
# Sum over query tokens
|
||||
score = max_sim_scores.sum()
|
||||
passage_scores.append(score)
|
||||
scores.append(torch.stack(passage_scores))
|
||||
return torch.stack(scores)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown similarity mode: {mode}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Similarity computation failed: {e}")
|
||||
raise
|
||||
|
||||
    def get_sentence_embedding_dimension(self) -> int:
        """Get the dimension of sentence embeddings"""
        # Dense sentence vectors keep the backbone's hidden width; self.config
        # may be an HF config or the BasicConfig fallback -- both expose
        # hidden_size.
        return self.config.hidden_size  # type: ignore[attr-defined]
|
||||
|
||||
    def get_word_embedding_dimension(self) -> int:
        """Get the dimension of word embeddings"""
        # NOTE(review): this always reports the backbone hidden size; when
        # use_colbert is set with colbert_dim > 0, token-level ColBERT vectors
        # have colbert_dim instead -- confirm callers expect the backbone width.
        return self.config.hidden_size  # type: ignore[attr-defined]
|
||||
483
models/bge_reranker.py
Normal file
483
models/bge_reranker.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
BGE-Reranker cross-encoder model wrapper
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
import numpy as np
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from utils.config_loader import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RerankerOutput:
    """Output class for reranker predictions"""
    # Final relevance scores produced by the reranker.
    scores: torch.Tensor
    # Raw head output before any post-processing.
    logits: torch.Tensor
    # Optional representation tensor; defaults to None when not provided.
    embeddings: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def load_reranker_model_and_tokenizer(model_name_or_path, cache_dir=None, num_labels=1):
    """Load reranker model and tokenizer from HuggingFace or ModelScope based on config.toml

    Args:
        model_name_or_path: Model identifier or local path.
        cache_dir: Optional cache directory for downloaded files.
        num_labels: Number of output labels for the classification head.

    Returns:
        Tuple of (model, tokenizer, model_config, use_base_model) where
        use_base_model is True when only a bare encoder could be loaded and
        the caller must attach its own classification head.

    Raises:
        RuntimeError: On an unknown model source.
    """
    config = get_config()

    # Try to get model source from config, with fallback to huggingface
    try:
        model_source = config.get('model_paths', 'source', 'huggingface')
    except Exception:
        logger.warning("Failed to load model source from config, defaulting to 'huggingface'")
        model_source = 'huggingface'

    logger.info(f"Loading reranker model from source: {model_source}")

    if model_source == 'huggingface':
        from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=num_labels,
            cache_dir=cache_dir
        )
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name_or_path,
                config=model_config,
                cache_dir=cache_dir
            )
            use_base_model = False
        # BUG FIX: this was a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit; narrowed to Exception.
        except Exception:
            # No sequence-classification head available -- fall back to the
            # bare encoder and let the caller attach a head.
            model = AutoModel.from_pretrained(
                model_name_or_path,
                config=model_config,
                cache_dir=cache_dir
            )
            use_base_model = True
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir
        )
        return model, tokenizer, model_config, use_base_model
    elif model_source == 'modelscope':
        try:
            from modelscope.models import Model as MSModel
            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
            ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
            model = MSModel.from_pretrained(model_name_or_path, cache_dir=cache_dir)
            tokenizer = None  # ModelScope tokenizer loading can be added if needed
            model_config = None
            use_base_model = False
            return model, tokenizer, model_config, use_base_model
        except ImportError:
            logger.error("modelscope not available. Install it first: pip install modelscope")
            raise
        except Exception as e:
            logger.error(f"Failed to load reranker model from ModelScope: {e}")
            raise
    else:
        logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
        raise RuntimeError(f"Unknown model source: {model_source}")
|
||||
|
||||
|
||||
class BGERerankerModel(nn.Module):
    """
    BGE-Reranker cross-encoder model for passage reranking.

    Notes:
        - All parameters (model_name_or_path, num_labels, cache_dir, use_fp16, gradient_checkpointing, device, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-reranker-base",
        num_labels: int = 1,
        cache_dir: Optional[str] = None,
        use_fp16: bool = False,
        gradient_checkpointing: bool = False,
        device: Optional[str] = None
    ):
        """
        Initialize BGE-Reranker model.

        Args:
            model_name_or_path: Path or identifier for pretrained model
            num_labels: Number of output labels
            cache_dir: Directory for model cache
            use_fp16: Use mixed precision
            gradient_checkpointing: Enable gradient checkpointing
            device: Device to use

        Raises:
            RuntimeError: If the backbone model cannot be loaded
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels
        self.use_fp16 = use_fp16

        # Load backbone, tokenizer and config via the config-driven loader
        try:
            logger.info(f"Loading BGE-Reranker model from {model_name_or_path} using source from config.toml...")
            self.model, self.tokenizer, self.config, self.use_base_model = load_reranker_model_and_tokenizer(
                model_name_or_path,
                cache_dir=cache_dir,
                num_labels=num_labels
            )
            logger.info("BGE-Reranker model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BGE-Reranker model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")

        # A custom classification head is only needed when the backbone is a
        # bare encoder (no sequence-classification head of its own)
        if self.use_base_model:
            self.classifier = nn.Linear(self.config.hidden_size, num_labels)  # type: ignore[attr-defined]
            self.dropout = nn.Dropout(
                self.config.hidden_dropout_prob if hasattr(self.config, 'hidden_dropout_prob') else 0.1)  # type: ignore[attr-defined]

        # Enable gradient checkpointing (best effort)
        if gradient_checkpointing:
            try:
                self.model.gradient_checkpointing_enable()
            except Exception:
                logger.warning("Gradient checkpointing requested but not supported by this model.")

        # Set device, falling back to CPU when the requested device fails
        if device:
            try:
                self.device = torch.device(device)
                self.to(self.device)
            except Exception as e:
                logger.warning(f"Failed to move model to {device}: {e}, using CPU")
                self.device = torch.device('cpu')
                self.to(self.device)
        else:
            self.device = next(self.parameters()).device

        # Mixed precision (best effort)
        if self.use_fp16:
            try:
                self.half()
            except Exception as e:
                logger.warning(f"Failed to set model to half precision: {e}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True
    ) -> Union[torch.Tensor, RerankerOutput]:
        """
        Forward pass for cross-encoder reranking

        Args:
            input_ids: Token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            labels: Labels for training (optional)
            return_dict: Return RerankerOutput object

        Returns:
            Scores or RerankerOutput
        """
        if self.use_base_model:
            # Bare encoder + custom head
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True
            )

            # Pooled representation: pooler output when available, else CLS token
            if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                pooled_output = outputs.pooler_output
            else:
                last_hidden_state = outputs.last_hidden_state
                pooled_output = last_hidden_state[:, 0]

            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)

            embeddings = pooled_output
        else:
            # Built-in sequence-classification head
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
                return_dict=True
            )
            logits = outputs.logits
            embeddings = None

        # Sigmoid for single-logit (binary) scoring, softmax otherwise
        if self.num_labels == 1:
            scores = torch.sigmoid(logits.squeeze(-1))
        else:
            scores = F.softmax(logits, dim=-1)

        if not return_dict:
            return scores

        return RerankerOutput(
            scores=scores,
            logits=logits,
            embeddings=embeddings
        )

    def predict(
        self,
        queries: Union[str, List[str]],
        passages: Union[str, List[str], List[List[str]]],
        batch_size: int = 32,
        max_length: int = 512,
        show_progress_bar: bool = False,
        convert_to_numpy: bool = True,
        normalize_scores: bool = False
    ) -> Union[np.ndarray, torch.Tensor, List[float]]:
        """
        Predict reranking scores for query-passage pairs

        Args:
            queries: Query or list of queries
            passages: Passage(s) or list of passages per query
            batch_size: Batch size for inference
            max_length: Maximum sequence length
            show_progress_bar: Show progress bar
            convert_to_numpy: Convert to numpy array
            normalize_scores: Normalize scores to [0, 1]

        Returns:
            Reranking scores (a Python scalar when exactly one pair is scored
            and convert_to_numpy is True)
        """
        # Normalize input formats to List[str] queries / List[List[str]] passages
        if isinstance(queries, str):
            queries = [queries]

        if isinstance(passages, str):
            passages = [[passages]]
        elif isinstance(passages[0], str):
            passages = [passages]  # type: ignore[assignment]

        # Broadcast a single query across multiple passage lists
        if len(queries) == 1 and len(passages) > 1:
            queries = queries * len(passages)

        # Flatten into (query, passage) pairs
        pairs = []
        for query, passage_list in zip(queries, passages):
            if isinstance(passage_list, str):
                passage_list = [passage_list]
            for passage in passage_list:
                pairs.append((query, passage))

        # Reuse the tokenizer loaded in __init__ when available; the original
        # re-loaded it from disk on every predict() call (fix)
        tokenizer = getattr(self, 'tokenizer', None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)

        all_scores = []

        pbar = None
        if show_progress_bar:
            from tqdm import tqdm
            pbar = tqdm(total=len(pairs), desc="Scoring")

        for i in range(0, len(pairs), batch_size):
            batch_pairs = pairs[i:i + batch_size]

            batch_queries = [pair[0] for pair in batch_pairs]
            batch_passages = [pair[1] for pair in batch_pairs]

            encoded = tokenizer(
                batch_queries,
                batch_passages,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors='pt'
            )

            # Move inputs to the model's device
            encoded = {k: v.to(self.device) for k, v in encoded.items()}

            with torch.no_grad():
                outputs = self.forward(**encoded, return_dict=True)
                scores = outputs.scores  # type: ignore[attr-defined]

            all_scores.append(scores)

            if pbar is not None:
                pbar.update(len(batch_pairs))

        if pbar is not None:
            pbar.close()

        all_scores = torch.cat(all_scores, dim=0)

        # Min-max normalize if requested (epsilon guards constant scores)
        if normalize_scores:
            all_scores = (all_scores - all_scores.min()) / (all_scores.max() - all_scores.min() + 1e-8)

        if convert_to_numpy:
            all_scores = all_scores.cpu().numpy()
        else:
            # Return a CPU tensor; the original `elif not isinstance(..., Tensor)`
            # branch could never fire after torch.cat (fix)
            all_scores = all_scores.cpu()

        # Return a scalar only when exactly one pair was scored; the original
        # keyed on len(queries)==1 and len(passages)==1, which crashed in
        # .item() when one query carried multiple passages (fix)
        if len(pairs) == 1:
            return all_scores.item() if convert_to_numpy else all_scores  # type: ignore[attr-defined]

        return all_scores

    def rerank(
        self,
        query: str,
        passages: List[str],
        top_k: Optional[int] = None,
        return_scores: bool = False,
        batch_size: int = 32,
        max_length: int = 512
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        """
        Rerank passages for a query

        Args:
            query: Query string
            passages: List of passages to rerank
            top_k: Return only top-k passages
            return_scores: Also return scores
            batch_size: Batch size for scoring
            max_length: Maximum sequence length

        Returns:
            Reranked passages (and optionally scores)
        """
        scores = self.predict(
            query,
            passages,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=True
        )

        # predict() returns a bare scalar for a single passage; promote to 1-D
        # so argsort/indexing below work uniformly (fix)
        scores = np.atleast_1d(np.asarray(scores))

        # Sort by score, descending
        sorted_indices = np.argsort(scores)[::-1]

        if top_k is not None:
            sorted_indices = sorted_indices[:top_k]

        reranked_passages = [passages[i] for i in sorted_indices]

        if return_scores:
            reranked_scores = [float(scores[i]) for i in sorted_indices]
            return reranked_passages, reranked_scores

        return reranked_passages

    def save_pretrained(self, save_path: str):
        """Save model to directory, including custom heads and config."""
        try:
            import os
            import json
            os.makedirs(save_path, exist_ok=True)

            # Save the base model
            self.model.save_pretrained(save_path)

            # Save additional configuration (including [reranker] section)
            from utils.config_loader import get_config
            config_dict = get_config().get_section('reranker')
            config = {
                'use_fp16': self.use_fp16 if hasattr(self, 'use_fp16') else False,
                'num_labels': getattr(self, 'num_labels', 1),
                'reranker_config': config_dict
            }
            with open(f"{save_path}/reranker_config.json", 'w') as f:
                json.dump(config, f, indent=2)

            # Save custom heads if present
            if hasattr(self, 'classifier'):
                torch.save(self.classifier.state_dict(), os.path.join(save_path, 'classifier.pt'))
            if hasattr(self, 'dropout'):
                torch.save(self.dropout.state_dict(), os.path.join(save_path, 'dropout.pt'))

            # Save tokenizer if available; the original checked a `_tokenizer`
            # attribute that is never set — __init__ stores `self.tokenizer` (fix)
            tokenizer = getattr(self, 'tokenizer', None) or getattr(self, '_tokenizer', None)
            if tokenizer is not None:
                tokenizer.save_pretrained(save_path)  # type: ignore[attr-defined]

            logger.info(f"Reranker model and config saved to {save_path}")
        except Exception as e:
            logger.error(f"Failed to save reranker model: {e}")
            raise

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from directory, restoring custom heads and config."""
        import os
        import json
        try:
            # Load saved reranker config, falling back to defaults
            config_path = os.path.join(model_path, "reranker_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}

            # Instantiate model (this already loads the backbone)
            model = cls(
                model_name_or_path=model_path,
                num_labels=config.get('num_labels', 1),
                use_fp16=config.get('use_fp16', False),
                **kwargs
            )

            # Reload base model weights from the checkpoint directory
            # NOTE(review): the fresh backbone is created on CPU — confirm
            # callers re-apply device placement after from_pretrained
            model.model = model.model.from_pretrained(model_path)

            # Load custom heads if present
            classifier_path = os.path.join(model_path, 'classifier.pt')
            if hasattr(model, 'classifier') and os.path.exists(classifier_path):
                model.classifier.load_state_dict(torch.load(classifier_path, map_location='cpu'))
            dropout_path = os.path.join(model_path, 'dropout.pt')
            if hasattr(model, 'dropout') and os.path.exists(dropout_path):
                model.dropout.load_state_dict(torch.load(dropout_path, map_location='cpu'))

            logger.info(f"Loaded BGE-Reranker model from {model_path}")
            return model
        except Exception as e:
            logger.error(f"Failed to load reranker model from {model_path}: {e}")
            raise

    def compute_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        loss_type: str = "binary_cross_entropy"
    ) -> torch.Tensor:
        """
        Compute loss for training

        Args:
            scores: Model scores (already sigmoided for num_labels == 1)
            labels: Ground truth labels
            loss_type: Type of loss function

        Returns:
            Loss value

        Raises:
            ValueError: If loss_type is not recognized
        """
        if loss_type == "binary_cross_entropy":
            return F.binary_cross_entropy(scores, labels.float())
        elif loss_type == "cross_entropy":
            return F.cross_entropy(scores, labels.long())
        elif loss_type == "mse":
            return F.mse_loss(scores, labels.float())
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")
|
||||
507
models/losses.py
Normal file
507
models/losses.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
Loss functions for BGE fine-tuning
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
class InfoNCELoss(nn.Module):
    """
    InfoNCE contrastive loss (BGE-M3 default temperature τ=0.02).

    Notes:
        - All parameters (temperature, reduction, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(self, temperature: float = 0.02, reduction: str = 'mean'):
        super().__init__()
        self.temperature = temperature
        self.reduction = reduction
        self.cross_entropy = nn.CrossEntropyLoss(reduction=reduction)

    def forward(
        self,
        query_embeddings: torch.Tensor,
        positive_embeddings: torch.Tensor,
        negative_embeddings: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss with L2-normalized embeddings.

        Args:
            query_embeddings: [batch_size, embedding_dim]
            positive_embeddings: [batch_size, embedding_dim]
            negative_embeddings: [batch_size, num_negatives, embedding_dim] or None
                (when None or empty, in-batch negatives are used)

        Returns:
            Scalar loss (per-sample when reduction='none')
        """
        n = query_embeddings.size(0)
        dev = query_embeddings.device

        # Work on unit vectors so dot products are cosine similarities
        q = F.normalize(query_embeddings, p=2, dim=1)
        pos = F.normalize(positive_embeddings, p=2, dim=1)

        use_explicit_negs = negative_embeddings is not None and negative_embeddings.numel() > 0

        if use_explicit_negs:
            negs = F.normalize(negative_embeddings, p=2, dim=-1)

            # Query-positive similarity: [batch_size, 1]
            pos_logits = (q * pos).sum(dim=1, keepdim=True) / self.temperature

            # Query-negative similarities: [batch_size, num_negatives]
            neg_logits = torch.bmm(negs, q.unsqueeze(2)).squeeze(2) / self.temperature

            # Positive sits at column 0 of every row
            logits = torch.cat([pos_logits, neg_logits], dim=1)
            targets = torch.zeros(n, device=dev, dtype=torch.long)
        else:
            # In-batch negatives: full query × positive similarity matrix,
            # true pairs on the diagonal
            logits = torch.matmul(q, pos.t()) / self.temperature
            targets = torch.arange(n, device=dev, dtype=torch.long)

        return self.cross_entropy(logits, targets)
|
||||
|
||||
|
||||
class M3KnowledgeDistillationLoss(nn.Module):
    """
    Knowledge distillation loss for BGE-M3's multi-functionality training
    Distills knowledge from dense, sparse, and multi-vector representations

    Notes:
        - All parameters (temperature, alpha, distill_types, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(
        self,
        temperature: float = 4.0,
        alpha: float = 0.5,
        distill_types: Optional[List[str]] = None
    ):
        """
        Args:
            temperature: Softening temperature for KL distillation
            alpha: Weight of the distillation term; (1 - alpha) weights the
                supervised term
            distill_types: Representation types to distill; defaults to
                ['dense', 'sparse', 'colbert']
        """
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        # Avoid a shared mutable default argument (fix): default list is built
        # per instance instead of at def-time
        self.distill_types = distill_types if distill_types is not None else ['dense', 'sparse', 'colbert']
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(
        self,
        student_outputs: Dict[str, torch.Tensor],
        teacher_outputs: Dict[str, torch.Tensor],
        labels: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            student_outputs: Dictionary with 'dense', 'sparse', 'colbert' outputs
            teacher_outputs: Dictionary with teacher model outputs
            labels: Optional labels for supervised loss

        Returns:
            total_loss: Combined distillation loss
            loss_dict: Dictionary of individual losses
        """
        total_loss = torch.tensor(0.0, device=next(iter(student_outputs.values())).device if student_outputs else 'cpu')
        loss_dict = {}

        # KL distillation for every representation both models produced
        for repr_type in self.distill_types:
            if repr_type in student_outputs and repr_type in teacher_outputs:
                student_logits = student_outputs[repr_type]
                teacher_logits = teacher_outputs[repr_type]

                # Temperature-scaled soft targets; the T^2 factor restores
                # gradient magnitude (standard Hinton distillation)
                student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
                teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)

                kl_loss = self.kl_div(student_log_probs, teacher_probs) * (self.temperature ** 2)
                total_loss = total_loss + self.alpha * kl_loss
                loss_dict[f'{repr_type}_distill_loss'] = kl_loss.detach().cpu().item()

        # Optional supervised cross-entropy on the dense head
        if labels is not None and 'dense' in student_outputs:
            ce_loss = F.cross_entropy(student_outputs['dense'], labels)
            total_loss = total_loss + (1 - self.alpha) * ce_loss
            loss_dict['supervised_loss'] = ce_loss.detach().cpu().item()

        loss_dict['total_loss'] = total_loss.detach().cpu().item()
        return total_loss, loss_dict
|
||||
|
||||
|
||||
class ListwiseCrossEntropyLoss(nn.Module):
    """
    Listwise cross-entropy loss for reranker training

    Notes:
        - All parameters (temperature, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
        """
        Compute listwise cross-entropy loss, averaged over contributing groups.

        Args:
            scores: Relevance scores [total_passages]
            labels: Binary relevance labels [total_passages]
            groups: List of group sizes (passages per query)

        Returns:
            Scalar loss (0 when no group has a positive label)
        """
        total_loss = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
        offset = 0
        num_contributing = 0  # groups that actually produced a loss term

        for group_size in groups:
            if group_size == 0:
                continue

            group_scores = scores[offset:offset + group_size]
            group_labels = labels[offset:offset + group_size]
            offset += group_size

            # A group with no positive labels defines no target distribution
            if group_labels.sum() == 0:
                continue

            # Temperature scaling + log_softmax for numerical stability
            log_probs = F.log_softmax(group_scores / self.temperature, dim=0)

            # Normalize labels into a probability distribution over passages
            label_probs = group_labels.float() / group_labels.sum()

            # Cross-entropy: -sum(p_true * log(p_pred))
            total_loss = total_loss + (-(label_probs * log_probs).sum())
            num_contributing += 1

        # Average over groups that contributed; the original divided by every
        # non-empty group, silently deflating the loss whenever some groups
        # had no positives (fix)
        if num_contributing > 0:
            return total_loss / num_contributing
        return total_loss
|
||||
|
||||
|
||||
class PairwiseMarginLoss(nn.Module):
    """
    Pairwise margin ranking loss for reranker

    Notes:
        - All parameters (margin, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, groups: List[int]) -> torch.Tensor:
        """
        Compute the mean hinge loss over all positive-negative pairs.

        Args:
            scores: Relevance scores [total_passages]
            labels: Binary relevance labels [total_passages]
            groups: List of group sizes

        Returns:
            Mean pairwise margin loss (0 when no group has both a positive
            and a negative passage)
        """
        total_loss = torch.tensor(0.0, device=scores.device)
        num_pairs = 0
        offset = 0

        for group_size in groups:
            group_scores = scores[offset:offset + group_size]
            group_labels = labels[offset:offset + group_size]
            offset += group_size

            pos_scores = group_scores[group_labels == 1]
            neg_scores = group_scores[group_labels == 0]

            # A group needs at least one positive and one negative to form pairs
            if pos_scores.numel() == 0 or neg_scores.numel() == 0:
                continue

            # Broadcast all positive-negative pairs at once; hinge
            # clamp(margin - (pos - neg), 0) is exactly MarginRankingLoss with
            # target=+1 — replaces the original O(P*N) Python double loop (fix)
            diffs = pos_scores.unsqueeze(1) - neg_scores.unsqueeze(0)
            pair_losses = torch.clamp(self.margin - diffs, min=0)

            total_loss = total_loss + pair_losses.sum()
            num_pairs += pair_losses.numel()

        return total_loss / max(num_pairs, 1)
|
||||
|
||||
|
||||
class DynamicListwiseDistillationLoss(nn.Module):
    """
    Dynamic listwise distillation loss for RocketQAv2-style joint training

    Notes:
        - All parameters (temperature, dynamic_weight, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(self, temperature: float = 3.0, dynamic_weight: bool = True):
        super().__init__()
        self.temperature = temperature
        self.dynamic_weight = dynamic_weight
        self.kl_div = nn.KLDivLoss(reduction='none')

    def forward(
        self,
        student_scores: torch.Tensor,
        teacher_scores: torch.Tensor,
        groups: List[int],
        confidence_scores: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Per-group KL divergence between student and teacher score
        distributions, optionally weighted by per-group confidence.

        Args:
            student_scores: Student model scores [total_passages]
            teacher_scores: Teacher model scores [total_passages]
            groups: List of group sizes
            confidence_scores: Optional confidence for dynamic weighting

        Returns:
            Mean per-group distillation loss (0 when there are no non-empty
            groups; the original raised ZeroDivisionError on an empty list)
        """
        total_loss = torch.tensor(0.0, device=student_scores.device)
        offset = 0
        num_groups = 0

        for i, group_size in enumerate(groups):
            # Guard: empty groups contribute nothing and must not inflate
            # the averaging denominator (fix)
            if group_size == 0:
                continue

            student_group = student_scores[offset:offset + group_size] / self.temperature
            teacher_group = teacher_scores[offset:offset + group_size] / self.temperature

            # Distributions over the passages of this group
            student_log_probs = F.log_softmax(student_group, dim=0)
            teacher_probs = F.softmax(teacher_group, dim=0)

            kl_loss = self.kl_div(student_log_probs, teacher_probs).sum()

            # Dynamic weighting based on per-group confidence
            if self.dynamic_weight and confidence_scores is not None:
                kl_loss = kl_loss * confidence_scores[i]

            # T^2 restores gradient scale after temperature softening
            total_loss = total_loss + kl_loss * (self.temperature ** 2)
            offset += group_size
            num_groups += 1

        # Average over non-empty groups; the original divided by len(groups),
        # crashing on [] and under-weighting when zero-size groups were present (fix)
        if num_groups > 0:
            return total_loss / num_groups
        return total_loss
|
||||
|
||||
|
||||
class MultiTaskLoss(nn.Module):
    """
    Multi-task loss for combining multiple objectives

    Notes:
        - All parameters (loss_weights, adaptive_weights, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(
        self,
        loss_weights: Optional[Dict[str, float]] = None,
        adaptive_weights: bool = False
    ):
        """
        Args:
            loss_weights: Fixed per-loss weights; defaults to the BGE-M3 set
                (dense/sparse/colbert/distillation)
            adaptive_weights: Learn per-loss weights via uncertainty weighting
        """
        super().__init__()
        self.loss_weights = loss_weights or {
            'dense': 1.0,
            'sparse': 0.3,
            'colbert': 0.5,
            'distillation': 1.0
        }
        self.adaptive_weights = adaptive_weights

        # Stable name -> log_vars index mapping; the original indexed log_vars
        # by enumeration order of the *incoming* losses dict, which misaligned
        # with the parameter when callers passed extra or reordered keys (fix)
        self._loss_index = {name: i for i, name in enumerate(self.loss_weights)}

        if adaptive_weights:
            # Learnable log-variances (uncertainty weighting), one per configured loss
            self.log_vars = nn.Parameter(torch.zeros(len(self.loss_weights)))

    def forward(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            losses: Dictionary of individual losses; keys not present in
                loss_weights are ignored

        Returns:
            total_loss: Combined loss
            loss_dict: Dictionary of individual loss values and weights
        """
        total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device if losses else 'cpu')
        loss_dict = {}

        if self.adaptive_weights:
            # Uncertainty-based weighting: loss_i * exp(-s_i) + s_i
            for loss_name, loss_value in losses.items():
                if loss_name in self.loss_weights:
                    idx = self._loss_index[loss_name]
                    precision = torch.exp(-self.log_vars[idx])
                    weighted_loss = precision * loss_value + self.log_vars[idx]
                    total_loss = total_loss + weighted_loss
                    loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value)
                    loss_dict[f'{loss_name}_weight'] = precision.detach().cpu().item()
        else:
            # Fixed weighting from loss_weights
            for loss_name, loss_value in losses.items():
                if loss_name in self.loss_weights:
                    weight = self.loss_weights[loss_name]
                    weighted_loss = weight * loss_value
                    total_loss = total_loss + weighted_loss
                    loss_dict[f'{loss_name}_loss'] = loss_value.detach().cpu().item() if isinstance(loss_value, torch.Tensor) else float(loss_value)
                    loss_dict[f'{loss_name}_weight'] = weight

        loss_dict['total_loss'] = total_loss.detach().cpu().item() if isinstance(total_loss, torch.Tensor) else float(total_loss)
        return total_loss, loss_dict
|
||||
|
||||
|
||||
class SparseLoss(nn.Module):
    """
    SPLADE+ sparse loss with log1p(ReLU(.)) and FLOPS regularization.
    For BGE-M3's token-position based sparse representations.

    Args:
        flops_loss_weight: Weight for FLOPS regularization
        sparse_reg_weight: Weight for sparsity regularization
        temperature: Temperature for similarity computation
    """

    def __init__(
        self,
        flops_loss_weight: float = 1e-3,
        sparse_reg_weight: float = 1e-4,
        temperature: float = 0.02
    ):
        super().__init__()
        self.flops_loss_weight = flops_loss_weight
        self.sparse_reg_weight = sparse_reg_weight
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(
        self,
        query_sparse: torch.Tensor,
        doc_sparse: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            query_sparse: [batch, query_seq_len] - Query sparse embeddings (token-position based)
            doc_sparse: [batch, doc_seq_len] - Document sparse embeddings (token-position based)
            labels: [batch] - Unused; in-batch contrastive labels are derived internally
        Returns:
            Tuple of (loss, loss_dict)
        """
        batch_size = query_sparse.size(0)
        device = query_sparse.device

        # SPLADE+ activation: log1p(ReLU(.))
        query_splade = torch.log1p(F.relu(query_sparse))
        doc_splade = torch.log1p(F.relu(doc_sparse))

        # Aggregate variable-length token scores to one scalar per example
        # via sum pooling, then compare with cosine similarity.
        # NOTE(review): after sum pooling each example reduces to a single
        # non-negative scalar, so the normalized similarity matrix is nearly
        # constant — confirm this aggregation is the intended sparse signal.
        query_scores = query_splade.sum(dim=1, keepdim=True)  # [batch, 1]
        doc_scores = doc_splade.sum(dim=1, keepdim=True)      # [batch, 1]

        # The original first built an unnormalized outer-product similarity
        # matrix and immediately overwrote it with the cosine form; only the
        # cosine path was ever used, so the dead computation is removed (fix)
        query_norm = F.normalize(query_scores, p=2, dim=1)
        doc_norm = F.normalize(doc_scores, p=2, dim=1)
        similarity_matrix = torch.matmul(query_norm, doc_norm.t())  # [batch, batch]

        # Scale by temperature for contrastive loss
        logits = similarity_matrix / self.temperature

        # In-batch contrastive labels: diagonal entries are positive pairs
        contrastive_labels = torch.arange(batch_size, device=device, dtype=torch.long)
        main_loss = self.cross_entropy(logits, contrastive_labels)

        # FLOPS regularization (encourage sparsity)
        flops = (query_splade.sum(dim=-1) + doc_splade.sum(dim=-1)).mean()

        # L1 sparsity regularization (mean activation mass)
        sparsity_loss = (query_splade.sum() + doc_splade.sum()) / (query_splade.numel() + doc_splade.numel())

        # All three terms are tensors by construction, so the original's
        # isinstance re-wrapping guards were dead code and are dropped (fix)
        total_loss = main_loss + self.flops_loss_weight * flops + self.sparse_reg_weight * sparsity_loss

        loss_dict = {
            'main_loss': main_loss.detach().cpu().item(),
            'flops': flops.detach().cpu().item(),
            'sparsity_loss': sparsity_loss.detach().cpu().item(),
            'total_loss': total_loss.detach().cpu().item()
        }

        return total_loss, loss_dict
|
||||
|
||||
|
||||
class ColBERTLoss(nn.Module):
    """
    ColBERT late interaction (maxsim) loss.

    Args:
        temperature: Softmax temperature
    """

    def __init__(self, temperature: float = 0.02):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query_embeds: torch.Tensor,
        doc_embeds: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            query_embeds: [batch, seq_len, dim]
            doc_embeds: [batch, seq_len, dim]
            labels: [batch]
        Returns:
            Loss value
        """
        # Late interaction: every query token is matched against its
        # best-scoring document token (MaxSim), temperature-scaled
        token_sims = torch.einsum('bqd,bkd->bqk', query_embeds, doc_embeds) / self.temperature
        best_per_query_token, _ = token_sims.max(dim=2)  # [batch, query_len]

        # Aggregate per-pair score as the mean over query tokens
        pair_scores = best_per_query_token.mean(dim=1)  # [batch]

        # Binary relevance objective on the aggregated logits
        return F.binary_cross_entropy_with_logits(pair_scores, labels.float())
|
||||
Reference in New Issue
Block a user