"""
|
|
BGE-M3 model wrapper for multi-functionality embedding
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
|
from typing import Dict, List, Optional, Union, Tuple
|
|
import numpy as np
|
|
import logging
|
|
from utils.config_loader import get_config
|
|
import os
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|


def load_model_and_tokenizer(model_name_or_path, cache_dir=None):
    """
    Load model and tokenizer with comprehensive error handling and fallback strategies.

    Args:
        model_name_or_path: Model identifier or local path
        cache_dir: Cache directory for downloaded models

    Returns:
        Tuple of (model, tokenizer, config)

    Raises:
        RuntimeError: If model loading fails after all attempts
    """
    config = get_config()
    model_source = config.get('model_paths', 'source', 'huggingface')

    # Track loading attempts for debugging
    attempts = []

    # Try primary loading method
    try:
        if model_source == 'huggingface':
            return _load_from_huggingface(model_name_or_path, cache_dir)
        elif model_source == 'modelscope':
            return _load_from_modelscope(model_name_or_path, cache_dir)
        else:
            raise ValueError(f"Unknown model source: {model_source}")
    except Exception as e:
        attempts.append(f"{model_source}: {str(e)}")
        logger.warning(f"Primary loading from {model_source} failed: {e}")

    # Fallback strategies
    fallback_sources = ['huggingface'] if model_source == 'modelscope' else ['modelscope']

    for fallback in fallback_sources:
        try:
            logger.info(f"Attempting fallback loading from {fallback}")
            if fallback == 'huggingface':
                return _load_from_huggingface(model_name_or_path, cache_dir)
            elif fallback == 'modelscope':
                return _load_from_modelscope(model_name_or_path, cache_dir)
        except Exception as e:
            attempts.append(f"{fallback}: {str(e)}")
            logger.warning(f"Fallback loading from {fallback} failed: {e}")

    # Try loading from local path if it exists
    if os.path.exists(model_name_or_path):
        try:
            logger.info(f"Attempting to load from local path: {model_name_or_path}")
            return _load_from_local(model_name_or_path, cache_dir)
        except Exception as e:
            attempts.append(f"local: {str(e)}")
            logger.warning(f"Local loading failed: {e}")

    # All attempts failed
    error_msg = "Failed to load model from any source. Attempts:\n" + "\n".join(attempts)
    logger.error(error_msg)
    raise RuntimeError(error_msg)
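
# Minimal usage sketch (assumes a config.toml with a [model_paths] section that sets
# `source = "huggingface"` or `source = "modelscope"`, as read via
# get_config().get('model_paths', 'source', ...) above; the cache_dir value is illustrative):
#
#   model, tokenizer, model_config = load_model_and_tokenizer(
#       "BAAI/bge-m3", cache_dir="./model_cache"
#   )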


def _load_from_huggingface(model_name_or_path, cache_dir=None):
    """Load model from HuggingFace with error handling"""
    try:
        from transformers import AutoModel, AutoTokenizer, AutoConfig

        # Load config first to validate model
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            trust_remote_code=True
        )

        # Load model
        model = AutoModel.from_pretrained(
            model_name_or_path,
            config=model_config,
            cache_dir=cache_dir,
            trust_remote_code=True
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            trust_remote_code=True
        )

        logger.info(f"Successfully loaded model from HuggingFace: {model_name_or_path}")
        return model, tokenizer, model_config

    except ImportError as e:
        raise ImportError(f"Transformers library not available: {e}")
    except Exception as e:
        raise RuntimeError(f"HuggingFace loading error: {e}")


def _load_from_modelscope(model_name_or_path, cache_dir=None):
    """Load model from ModelScope with error handling"""
    try:
        from modelscope.models import Model as MSModel
        from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download

        # Download model if needed
        local_path = ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)

        # Load model
        model = MSModel.from_pretrained(local_path)

        # Try to load tokenizer
        tokenizer = None
        try:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(local_path)
        except Exception as e:
            logger.warning(f"Could not load tokenizer from ModelScope: {e}")

        # Try to get config
        config = None
        config_path = os.path.join(local_path, 'config.json')
        if os.path.exists(config_path):
            import json
            with open(config_path, 'r') as f:
                config = json.load(f)

        logger.info(f"Successfully loaded model from ModelScope: {model_name_or_path}")
        return model, tokenizer, config

    except ImportError as e:
        raise ImportError(f"ModelScope library not available: {e}")
    except Exception as e:
        raise RuntimeError(f"ModelScope loading error: {e}")


def _load_from_local(model_path, cache_dir=None):
    """Load model from local directory with error handling"""
    try:
        from transformers import AutoModel, AutoTokenizer, AutoConfig

        if not os.path.isdir(model_path):
            raise ValueError(f"Local path is not a directory: {model_path}")

        # Check for required files
        required_files = ['config.json']
        missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
        if missing_files:
            raise FileNotFoundError(f"Missing required files: {missing_files}")

        # Load components
        model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_path, config=model_config, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        logger.info(f"Successfully loaded model from local path: {model_path}")
        return model, tokenizer, model_config

    except Exception as e:
        raise RuntimeError(f"Local loading error: {e}")


class BGEM3Model(nn.Module):
    """
    BGE-M3 model wrapper supporting dense, sparse, and multi-vector (ColBERT) retrieval.

    Notes:
        - All parameters (model_name_or_path, use_dense, use_sparse, use_colbert,
          pooling_method, etc.) should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from your
          main script for consistency.
    """

    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-m3",
        use_dense: bool = True,
        use_sparse: bool = True,
        use_colbert: bool = True,
        pooling_method: str = "cls",
        normalize_embeddings: bool = True,
        sentence_pooling_method: str = "cls",
        colbert_dim: int = -1,
        sparse_top_k: int = 100,
        cache_dir: Optional[str] = None,
        device: Optional[str] = None,
        gradient_checkpointing: bool = False
    ):
        super().__init__()

        self.model_name_or_path = model_name_or_path
        self.use_dense = use_dense
        self.use_sparse = use_sparse
        self.use_colbert = use_colbert
        self.pooling_method = pooling_method
        self.normalize_embeddings = normalize_embeddings
        # Store the sentence pooling method; forward() falls back to it when pooling
        self.sentence_pooling_method = sentence_pooling_method
        self.colbert_dim = colbert_dim
        self.sparse_top_k = sparse_top_k
        self.gradient_checkpointing = gradient_checkpointing

        # Initialize model with comprehensive error handling
        try:
            logger.info(f"Loading BGE-M3 model from {model_name_or_path} using source from config.toml...")
            self.model, self.tokenizer, self.config = load_model_and_tokenizer(
                model_name_or_path,
                cache_dir=cache_dir
            )
            logger.info("BGE-M3 model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BGE-M3 model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")

        # Enable gradient checkpointing if requested
        if self.gradient_checkpointing:
            self.enable_gradient_checkpointing()

        # Ensure we have a valid config
        if self.config is None:
            logger.warning("Model config is None, creating a basic config")

            # Create a basic config object
            class BasicConfig:
                def __init__(self):
                    self.hidden_size = 1024  # BGE-M3 (XLM-RoBERTa-large backbone) hidden size

            self.config = BasicConfig()

        # Initialize heads for different functionalities
        # (handle both AutoConfig objects and plain dict configs from the ModelScope path)
        if isinstance(self.config, dict):
            hidden_size = self.config.get('hidden_size', 1024)
        else:
            hidden_size = getattr(self.config, 'hidden_size', 1024)

        if self.use_dense:
            # Dense retrieval doesn't need additional layers (uses CLS token)
            self.dense_linear = nn.Identity()

        if self.use_sparse:
            # Sparse retrieval head (SPLADE-style)
            self.sparse_linear = nn.Linear(hidden_size, 1)
            self.sparse_activation = nn.ReLU()

        if self.use_colbert:
            # ColBERT projection layer
            colbert_dim = colbert_dim if colbert_dim > 0 else hidden_size
            self.colbert_linear = nn.Linear(hidden_size, colbert_dim)

        # Set device with enhanced detection
        if device:
            self.device = torch.device(device)
        else:
            self.device = self._detect_device()

        try:
            self.to(self.device)
            logger.info(f"BGE-M3 model moved to device: {self.device}")
        except Exception as e:
            logger.warning(f"Failed to move model to {self.device}: {e}, using CPU")
            self.device = torch.device('cpu')
            self.to(self.device)

        # Cache for the lazily loaded tokenizer (see _get_tokenizer)
        self._tokenizer = None

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for memory efficiency"""
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()
            logger.info("Gradient checkpointing enabled for BGE-M3")
        else:
            logger.warning("Model does not support gradient checkpointing")

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing"""
        if hasattr(self.model, 'gradient_checkpointing_disable'):
            self.model.gradient_checkpointing_disable()
            logger.info("Gradient checkpointing disabled for BGE-M3")

    def _detect_device(self) -> torch.device:
        """Enhanced device detection with proper error handling"""
        # Check Ascend NPU first
        try:
            import torch_npu  # noqa: F401  (registers the NPU backend)
            if hasattr(torch, 'npu') and torch.npu.is_available():  # type: ignore[attr-defined]
                logger.info("Using Ascend NPU")
                return torch.device("npu")
        except ImportError:
            pass
        except Exception as e:
            logger.debug(f"NPU detection failed: {e}")

        # Check CUDA
        if torch.cuda.is_available():
            logger.info("Using CUDA GPU")
            return torch.device("cuda")

        # Check MPS (Apple Silicon)
        try:
            if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():  # type: ignore[attr-defined]
                logger.info("Using Apple MPS")
                return torch.device("mps")
        except Exception as e:
            logger.debug(f"MPS detection failed: {e}")

        # Default to CPU
        logger.info("Using CPU")
        return torch.device("cpu")

    def _get_tokenizer(self):
        """Lazily obtain the tokenizer, reusing the one returned by the loader when available"""
        if self._tokenizer is None:
            if getattr(self, 'tokenizer', None) is not None:
                self._tokenizer = self.tokenizer
            else:
                try:
                    self._tokenizer = AutoTokenizer.from_pretrained(
                        self.model_name_or_path,
                        trust_remote_code=True
                    )
                    logger.info("Tokenizer loaded successfully")
                except Exception as e:
                    logger.error(f"Failed to load tokenizer: {e}")
                    raise RuntimeError(f"Tokenizer loading failed: {e}")
        return self._tokenizer

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        return_dense: bool = True,
        return_sparse: bool = False,
        return_colbert: bool = False,
        return_dict: bool = True
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass for BGE-M3 with dense, sparse (SPLADE+), and ColBERT heads.

        Args:
            input_ids: Token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            return_dense: Return dense embeddings
            return_sparse: Return sparse embeddings
            return_colbert: Return ColBERT embeddings
            return_dict: Return dictionary of outputs

        Returns:
            Embeddings or dictionary of embeddings
        """
        # Validate inputs
        if input_ids is None or attention_mask is None:
            raise ValueError("input_ids and attention_mask are required")
        if input_ids.shape != attention_mask.shape:
            raise ValueError("input_ids and attention_mask must have the same shape")

        # Determine what to return based on model configuration and parameters
        return_dense = return_dense and self.use_dense
        return_sparse = return_sparse and self.use_sparse
        return_colbert = return_colbert and self.use_colbert

        if not (return_dense or return_sparse or return_colbert):
            logger.warning("No output head selected in forward. At least one of dense, sparse, or colbert must be True.")
            return_dense = True  # Default to dense if nothing selected

        try:
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True
            )
            last_hidden_state = outputs.last_hidden_state  # [batch_size, seq_length, hidden_size]
        except Exception as e:
            logger.error(f"Model forward pass failed: {e}")
            raise RuntimeError(f"Forward pass failed: {e}")

        results = {}

        # Dense embeddings (CLS or mean pooling)
        if return_dense:
            if self.pooling_method == "cls" or self.sentence_pooling_method == "cls":
                dense_vecs = last_hidden_state[:, 0]
            elif self.pooling_method == "mean" or self.sentence_pooling_method == "mean":
                expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.shape)
                masked_embeddings = last_hidden_state * expanded_mask
                sum_embeddings = torch.sum(masked_embeddings, dim=1)
                sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
                dense_vecs = sum_embeddings / sum_mask
            else:
                raise ValueError(f"Unknown pooling method: {self.pooling_method}")

            dense_vecs = self.dense_linear(dense_vecs)
            # Ensure embeddings are float type
            dense_vecs = dense_vecs.float()
            if self.normalize_embeddings:
                dense_vecs = F.normalize(dense_vecs, p=2, dim=-1)
            results['dense'] = dense_vecs

        # SPLADE+ sparse head
        if return_sparse:
            # SPLADE+ uses log1p(ReLU(.))
            sparse_logits = self.sparse_linear(last_hidden_state).squeeze(-1)  # [batch, seq_len]
            sparse_activations = torch.log1p(F.relu(sparse_logits))

            # Mask out padding tokens
            sparse_activations = sparse_activations * attention_mask.float()

            # Optionally normalize sparse activations
            if self.normalize_embeddings:
                sparse_activations = F.normalize(sparse_activations, p=2, dim=-1)

            results['sparse'] = sparse_activations

            # FLOPS regularization term (stored as an attribute for the loss, not returned by forward)
            self.splade_flops = torch.sum(sparse_activations)

        # ColBERT head with late interaction (MaxSim)
        if return_colbert:
            colbert_vecs = self.colbert_linear(last_hidden_state)  # [batch, seq_len, colbert_dim]

            # Mask out padding tokens
            expanded_mask = attention_mask.unsqueeze(-1).expand(colbert_vecs.shape)
            colbert_vecs = colbert_vecs * expanded_mask

            if self.normalize_embeddings:
                colbert_vecs = F.normalize(colbert_vecs, p=2, dim=-1)

            results['colbert'] = colbert_vecs

        if return_dict:
            return results

        # If only one head was requested, return that tensor directly
        if return_dense and len(results) == 1:
            return results['dense']
        if return_sparse and len(results) == 1:
            return results['sparse']
        if return_colbert and len(results) == 1:
            return results['colbert']

        return results
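
    # Sketch of how the FLOPS term exposed by the sparse head might be consumed in a
    # training loop (illustrative only; `ranking_loss` and the weighting `lambda_flops`
    # are hypothetical names, not defined in this module):
    #
    #   out = model(input_ids, attention_mask, return_sparse=True)
    #   loss = ranking_loss + lambda_flops * model.splade_flops
    #
    # Note that splade_flops is recomputed on every forward call that requests the sparse head.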

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        show_progress_bar: bool = False,
        normalize_embeddings: bool = True,
        return_dense: bool = True,
        return_sparse: bool = False,
        return_colbert: bool = False,
        device: Optional[str] = None
    ) -> Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]:
        """
        Complete encode implementation with comprehensive error handling.

        Args:
            sentences: Input sentences
            batch_size: Batch size for encoding
            max_length: Maximum sequence length
            convert_to_numpy: Convert to numpy arrays
            convert_to_tensor: Keep as tensors
            show_progress_bar: Show progress bar
            normalize_embeddings: Normalize embeddings
            return_dense/sparse/colbert: Which embeddings to return
            device: Device to use

        Returns:
            Embeddings
        """
        # Input validation
        if isinstance(sentences, str):
            sentences = [sentences]

        if not sentences:
            raise ValueError("Empty sentences list provided")

        if batch_size <= 0:
            raise ValueError("batch_size must be positive")

        if max_length <= 0:
            raise ValueError("max_length must be positive")

        # Get tokenizer
        try:
            tokenizer = self._get_tokenizer()
        except Exception as e:
            raise RuntimeError(f"Failed to load tokenizer: {e}")

        # Determine target device
        target_device = device if device else self.device

        # Initialize storage for embeddings
        all_embeddings = {
            'dense': [] if return_dense else None,
            'sparse': [] if return_sparse else None,
            'colbert': [] if return_colbert else None
        }

        # Set model to evaluation mode
        self.eval()

        try:
            # Import tqdm if show_progress_bar is True
            if show_progress_bar:
                try:
                    from tqdm import tqdm
                    iterator = tqdm(
                        range(0, len(sentences), batch_size),
                        desc="Encoding",
                        total=(len(sentences) + batch_size - 1) // batch_size
                    )
                except ImportError:
                    logger.warning("tqdm not available, disabling progress bar")
                    iterator = range(0, len(sentences), batch_size)
            else:
                iterator = range(0, len(sentences), batch_size)

            # Process in batches
            for i in iterator:
                batch_sentences = sentences[i:i + batch_size]

                # Tokenize with proper error handling
                try:
                    encoded = tokenizer(
                        batch_sentences,
                        padding=True,
                        truncation=True,
                        max_length=max_length,
                        return_tensors='pt'
                    )
                except Exception as e:
                    raise RuntimeError(f"Tokenization failed for batch {i//batch_size + 1}: {e}")

                # Move to device
                try:
                    encoded = {k: v.to(target_device) for k, v in encoded.items()}
                except Exception as e:
                    raise RuntimeError(f"Failed to move batch to device {target_device}: {e}")

                # Get embeddings
                with torch.no_grad():
                    try:
                        outputs = self.forward(
                            **encoded,
                            return_dense=return_dense,
                            return_sparse=return_sparse,
                            return_colbert=return_colbert,
                            return_dict=True
                        )
                    except Exception as e:
                        raise RuntimeError(f"Forward pass failed for batch {i//batch_size + 1}: {e}")

                # Collect embeddings
                for key in ['dense', 'sparse', 'colbert']:
                    if all_embeddings[key] is not None and key in outputs:
                        all_embeddings[key].append(outputs[key].cpu())  # type: ignore[arg-type]

            # Concatenate and convert results
            final_embeddings = {}
            for key in ['dense', 'sparse', 'colbert']:
                if all_embeddings[key] is not None and len(all_embeddings[key]) > 0:  # type: ignore[arg-type]
                    try:
                        embeddings = torch.cat(all_embeddings[key], dim=0)  # type: ignore[arg-type]

                        if normalize_embeddings and key == 'dense':
                            embeddings = F.normalize(embeddings, p=2, dim=-1)

                        if convert_to_numpy:
                            embeddings = embeddings.numpy()

                        final_embeddings[key] = embeddings

                    except Exception as e:
                        logger.error(f"Failed to process {key} embeddings: {e}")
                        raise

            # Return format
            if len(final_embeddings) == 1:
                return list(final_embeddings.values())[0]

            if not final_embeddings:
                raise RuntimeError("No embeddings were generated")

            return final_embeddings

        except Exception as e:
            logger.error(f"Encoding failed: {e}")
            raise
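
    # Minimal encode() usage sketch (illustrative; the sentence strings are made up):
    #
    #   embs = model.encode(
    #       ["what is dense retrieval?", "BGE-M3 supports three retrieval modes"],
    #       batch_size=2,
    #       return_dense=True,
    #       return_sparse=True,
    #       convert_to_numpy=True,
    #   )
    #   # embs is a dict here because two heads were requested:
    #   #   embs['dense']  -> [2, hidden_size] array
    #   #   embs['sparse'] -> [2, seq_len] array of per-token weights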

    def save_pretrained(self, save_path: str):
        """Save model to directory with error handling, including custom heads and config."""
        try:
            import json

            os.makedirs(save_path, exist_ok=True)

            # Save the base model
            self.model.save_pretrained(save_path)

            # Save additional configuration (including the [m3] section)
            config_dict = get_config().get_section('m3')
            config = {
                'use_dense': self.use_dense,
                'use_sparse': self.use_sparse,
                'use_colbert': self.use_colbert,
                'pooling_method': self.pooling_method,
                'sentence_pooling_method': self.sentence_pooling_method,
                'normalize_embeddings': self.normalize_embeddings,
                'colbert_dim': self.colbert_dim,
                'sparse_top_k': self.sparse_top_k,
                'model_name_or_path': self.model_name_or_path,
                'm3_config': config_dict
            }
            with open(os.path.join(save_path, "m3_config.json"), 'w') as f:
                json.dump(config, f, indent=2)

            # Save custom heads if needed
            if self.use_sparse:
                torch.save(self.sparse_linear.state_dict(), os.path.join(save_path, 'sparse_linear.pt'))
            if self.use_colbert:
                torch.save(self.colbert_linear.state_dict(), os.path.join(save_path, 'colbert_linear.pt'))

            # Save tokenizer if available
            if self._tokenizer is not None:
                self._tokenizer.save_pretrained(save_path)

            logger.info(f"Model and config saved to {save_path}")
        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from directory, restoring custom heads and config."""
        import json
        try:
            # Load config
            config_path = os.path.join(model_path, "m3_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}

            # Instantiate model
            model = cls(
                model_name_or_path=config.get('model_name_or_path', model_path),
                use_dense=config.get('use_dense', True),
                use_sparse=config.get('use_sparse', True),
                use_colbert=config.get('use_colbert', True),
                pooling_method=config.get('pooling_method', 'cls'),
                sentence_pooling_method=config.get('sentence_pooling_method', 'cls'),
                normalize_embeddings=config.get('normalize_embeddings', True),
                colbert_dim=config.get('colbert_dim', -1),
                sparse_top_k=config.get('sparse_top_k', 100),
                **kwargs
            )

            # Load base model weights
            model.model = model.model.from_pretrained(model_path)

            # Load custom heads if present
            if model.use_sparse:
                sparse_path = os.path.join(model_path, 'sparse_linear.pt')
                if os.path.exists(sparse_path):
                    model.sparse_linear.load_state_dict(torch.load(sparse_path, map_location='cpu'))
            if model.use_colbert:
                colbert_path = os.path.join(model_path, 'colbert_linear.pt')
                if os.path.exists(colbert_path):
                    model.colbert_linear.load_state_dict(torch.load(colbert_path, map_location='cpu'))

            # Make sure the freshly loaded weights end up on the model's device
            model.to(model.device)

            logger.info(f"Loaded BGE-M3 model from {model_path}")
            return model
        except Exception as e:
            logger.error(f"Failed to load model from {model_path}: {e}")
            raise
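
    # Save/load round-trip sketch (paths are illustrative):
    #
    #   model.save_pretrained("./checkpoints/bge-m3-finetuned")
    #   restored = BGEM3Model.from_pretrained("./checkpoints/bge-m3-finetuned")
    #
    # The base transformer is stored via its own save_pretrained(), while the sparse and
    # ColBERT heads are stored as separate state_dict files (sparse_linear.pt,
    # colbert_linear.pt) alongside m3_config.json.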

    def compute_similarity(
        self,
        queries: Union[torch.Tensor, Dict[str, torch.Tensor]],
        passages: Union[torch.Tensor, Dict[str, torch.Tensor]],
        mode: str = "dense"
    ) -> torch.Tensor:
        """
        Compute similarity scores between queries and passages with error handling.

        Args:
            queries: Query embeddings
            passages: Passage embeddings
            mode: Similarity mode ('dense', 'sparse', 'colbert')

        Returns:
            Similarity scores [num_queries, num_passages]
        """
        try:
            # Extract embeddings based on mode
            if isinstance(queries, dict):
                if mode not in queries:
                    raise ValueError(f"Mode '{mode}' not found in query embeddings")
                query_emb = queries[mode]
            else:
                query_emb = queries

            if isinstance(passages, dict):
                if mode not in passages:
                    raise ValueError(f"Mode '{mode}' not found in passage embeddings")
                passage_emb = passages[mode]
            else:
                passage_emb = passages

            # Validate tensor shapes
            if query_emb.dim() < 2 or passage_emb.dim() < 2:
                raise ValueError("Embeddings must be at least 2-dimensional")

            if mode == "dense":
                # Simple dot product for dense
                return torch.matmul(query_emb, passage_emb.t())

            elif mode == "sparse":
                # Dot product for sparse (simplified)
                return torch.matmul(query_emb, passage_emb.t())

            elif mode == "colbert":
                # MaxSim for ColBERT
                if query_emb.dim() != 3 or passage_emb.dim() != 3:
                    raise ValueError("ColBERT embeddings must be 3-dimensional [batch, seq_len, dim]")

                scores = []
                for q in query_emb:
                    # q: [query_len, dim]
                    # passage_emb: [num_passages, passage_len, dim]
                    passage_scores = []
                    for p in passage_emb:
                        # Compute similarity matrix
                        sim_matrix = torch.matmul(q, p.t())  # [query_len, passage_len]
                        # Max over passage tokens for each query token
                        max_sim_scores = sim_matrix.max(dim=1)[0]  # [query_len]
                        # Sum over query tokens
                        score = max_sim_scores.sum()
                        passage_scores.append(score)
                    scores.append(torch.stack(passage_scores))
                return torch.stack(scores)

            else:
                raise ValueError(f"Unknown similarity mode: {mode}")

        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            raise
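
    # The ColBERT branch above implements the standard late-interaction MaxSim score,
    #   score(q, p) = sum_i max_j <q_i, p_j>,
    # where q_i are query token vectors and p_j are passage token vectors. The nested
    # Python loops keep the logic explicit; an equivalent (but untested here) vectorized
    # sketch would be:
    #
    #   sim = torch.einsum('qid,pjd->qpij', query_emb, passage_emb)  # [Q, P, Lq, Lp]
    #   scores = sim.max(dim=-1).values.sum(dim=-1)                  # [Q, P]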

    def get_sentence_embedding_dimension(self) -> int:
        """Get the dimension of sentence embeddings"""
        return self.config.hidden_size  # type: ignore[attr-defined]

    def get_word_embedding_dimension(self) -> int:
        """Get the dimension of word embeddings"""
        return self.config.hidden_size  # type: ignore[attr-defined]
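

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library API): assumes the
    # BAAI/bge-m3 weights and the project's config.toml are reachable from this
    # environment, and that downloading the model is acceptable here.
    logging.basicConfig(level=logging.INFO)

    m3 = BGEM3Model(model_name_or_path="BAAI/bge-m3", use_sparse=False, use_colbert=False)
    vecs = m3.encode(["hello world", "BGE-M3 multi-functionality embedding"], batch_size=2)
    # With only the dense head requested, encode() returns a single numpy array
    print("dense embeddings shape:", vecs.shape)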