"""
BGE-M3 model wrapper for multi-functionality embedding
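Provides dense (CLS/mean-pooled), sparse (SPLADE-style token weights), and
multi-vector (ColBERT) representations behind a single nn.Module, with
config-driven model loading that falls back between HuggingFace, ModelScope,
and local checkpoints.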
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
import logging
from utils.config_loader import get_config
import os
logger = logging.getLogger(__name__)
def load_model_and_tokenizer(model_name_or_path, cache_dir=None):
"""
Load model and tokenizer with comprehensive error handling and fallback strategies
Args:
model_name_or_path: Model identifier or local path
cache_dir: Cache directory for downloaded models
Returns:
Tuple of (model, tokenizer, config)
Raises:
RuntimeError: If model loading fails after all attempts
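    Example (illustrative; assumes config.toml provides a [model_paths] source entry):
        model, tokenizer, config = load_model_and_tokenizer("BAAI/bge-m3", cache_dir="./cache")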
"""
config = get_config()
model_source = config.get('model_paths', 'source', 'huggingface')
# Track loading attempts for debugging
attempts = []
# Try primary loading method
try:
if model_source == 'huggingface':
return _load_from_huggingface(model_name_or_path, cache_dir)
elif model_source == 'modelscope':
return _load_from_modelscope(model_name_or_path, cache_dir)
else:
raise ValueError(f"Unknown model source: {model_source}")
except Exception as e:
attempts.append(f"{model_source}: {str(e)}")
logger.warning(f"Primary loading from {model_source} failed: {e}")
# Fallback strategies
fallback_sources = ['huggingface'] if model_source == 'modelscope' else ['modelscope']
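    # i.e. always try the other hub first; a plain local directory is the last resort below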
for fallback in fallback_sources:
try:
logger.info(f"Attempting fallback loading from {fallback}")
if fallback == 'huggingface':
return _load_from_huggingface(model_name_or_path, cache_dir)
elif fallback == 'modelscope':
return _load_from_modelscope(model_name_or_path, cache_dir)
except Exception as e:
attempts.append(f"{fallback}: {str(e)}")
logger.warning(f"Fallback loading from {fallback} failed: {e}")
# Try loading from local path if it exists
if os.path.exists(model_name_or_path):
try:
logger.info(f"Attempting to load from local path: {model_name_or_path}")
return _load_from_local(model_name_or_path, cache_dir)
except Exception as e:
attempts.append(f"local: {str(e)}")
logger.warning(f"Local loading failed: {e}")
# All attempts failed
error_msg = f"Failed to load model from any source. Attempts:\n" + "\n".join(attempts)
logger.error(error_msg)
raise RuntimeError(error_msg)
def _load_from_huggingface(model_name_or_path, cache_dir=None):
"""Load model from HuggingFace with error handling"""
try:
from transformers import AutoModel, AutoTokenizer, AutoConfig
# Load config first to validate model
model_config = AutoConfig.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
trust_remote_code=True
)
# Load model
model = AutoModel.from_pretrained(
model_name_or_path,
config=model_config,
cache_dir=cache_dir,
trust_remote_code=True
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
trust_remote_code=True
)
logger.info(f"Successfully loaded model from HuggingFace: {model_name_or_path}")
return model, tokenizer, model_config
except ImportError as e:
raise ImportError(f"Transformers library not available: {e}")
except Exception as e:
raise RuntimeError(f"HuggingFace loading error: {e}")
def _load_from_modelscope(model_name_or_path, cache_dir=None):
"""Load model from ModelScope with error handling"""
try:
from modelscope.models import Model as MSModel
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
# Download model if needed
local_path = ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
# Load model
model = MSModel.from_pretrained(local_path)
# Try to load tokenizer
tokenizer = None
try:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(local_path)
except Exception as e:
logger.warning(f"Could not load tokenizer from ModelScope: {e}")
# Try to get config
config = None
config_path = os.path.join(local_path, 'config.json')
if os.path.exists(config_path):
import json
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Successfully loaded model from ModelScope: {model_name_or_path}")
return model, tokenizer, config
except ImportError as e:
raise ImportError(f"ModelScope library not available: {e}")
except Exception as e:
raise RuntimeError(f"ModelScope loading error: {e}")
def _load_from_local(model_path, cache_dir=None):
"""Load model from local directory with error handling"""
try:
from transformers import AutoModel, AutoTokenizer, AutoConfig
if not os.path.isdir(model_path):
raise ValueError(f"Local path is not a directory: {model_path}")
# Check for required files
required_files = ['config.json']
missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
if missing_files:
raise FileNotFoundError(f"Missing required files: {missing_files}")
# Load components
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, config=model_config, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
logger.info(f"Successfully loaded model from local path: {model_path}")
return model, tokenizer, model_config
except Exception as e:
raise RuntimeError(f"Local loading error: {e}")
class BGEM3Model(nn.Module):
"""
BGE-M3 model wrapper supporting dense, sparse, and multi-vector retrieval
Notes:
- All parameters (model_name_or_path, use_dense, use_sparse, use_colbert, pooling_method, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
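    Example (illustrative sketch; mirrors this module's defaults, not a guaranteed API):
        model = BGEM3Model("BAAI/bge-m3", use_dense=True, use_sparse=True, use_colbert=False)
        dense = model.encode(["hello world"])  # np.ndarray of shape [1, hidden_size]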
"""
def __init__(
self,
model_name_or_path: str = "BAAI/bge-m3",
use_dense: bool = True,
use_sparse: bool = True,
use_colbert: bool = True,
pooling_method: str = "cls",
normalize_embeddings: bool = True,
sentence_pooling_method: str = "cls",
colbert_dim: int = -1,
sparse_top_k: int = 100,
cache_dir: Optional[str] = None,
device: Optional[str] = None,
gradient_checkpointing: bool = False
):
super().__init__()
self.model_name_or_path = model_name_or_path
self.use_dense = use_dense
self.use_sparse = use_sparse
self.use_colbert = use_colbert
        self.pooling_method = pooling_method
        self.sentence_pooling_method = sentence_pooling_method
self.normalize_embeddings = normalize_embeddings
self.colbert_dim = colbert_dim
self.sparse_top_k = sparse_top_k
self.gradient_checkpointing = gradient_checkpointing
# Initialize model with comprehensive error handling
try:
logger.info(f"Loading BGE-M3 model from {model_name_or_path} using source from config.toml...")
self.model, self.tokenizer, self.config = load_model_and_tokenizer(
model_name_or_path,
cache_dir=cache_dir
)
logger.info("BGE-M3 model loaded successfully")
except Exception as e:
logger.error(f"Failed to load BGE-M3 model: {e}")
raise RuntimeError(f"Model initialization failed: {e}")
# Enable gradient checkpointing if requested
if self.gradient_checkpointing:
self.enable_gradient_checkpointing()
# Ensure we have a valid config
if self.config is None:
logger.warning("Model config is None, creating a basic config")
# Create a basic config object
class BasicConfig:
def __init__(self):
self.hidden_size = 768 # Default BERT hidden size
self.config = BasicConfig()
        # Initialize heads for different functionalities
        # (self.config may be a transformers config object or a plain dict from the ModelScope path)
        hidden_size = (self.config.get('hidden_size', 768) if isinstance(self.config, dict)
                       else getattr(self.config, 'hidden_size', 768))
if self.use_dense:
# Dense retrieval doesn't need additional layers (uses CLS token)
self.dense_linear = nn.Identity()
if self.use_sparse:
# Sparse retrieval head (SPLADE-style)
self.sparse_linear = nn.Linear(hidden_size, 1)
self.sparse_activation = nn.ReLU()
if self.use_colbert:
# ColBERT projection layer
colbert_dim = colbert_dim if colbert_dim > 0 else hidden_size
self.colbert_linear = nn.Linear(hidden_size, colbert_dim)
# Set device with enhanced detection
if device:
self.device = torch.device(device)
else:
self.device = self._detect_device()
try:
self.to(self.device)
logger.info(f"BGE-M3 model moved to device: {self.device}")
except Exception as e:
logger.warning(f"Failed to move model to {self.device}: {e}, using CPU")
self.device = torch.device('cpu')
self.to(self.device)
        # Cache the tokenizer returned by the loader (may be None, e.g. from ModelScope);
        # _get_tokenizer() reloads it lazily if it is missing.
        self._tokenizer = self.tokenizer
def enable_gradient_checkpointing(self):
"""Enable gradient checkpointing for memory efficiency"""
if hasattr(self.model, 'gradient_checkpointing_enable'):
self.model.gradient_checkpointing_enable()
logger.info("Gradient checkpointing enabled for BGE-M3")
else:
logger.warning("Model does not support gradient checkpointing")
def disable_gradient_checkpointing(self):
"""Disable gradient checkpointing"""
if hasattr(self.model, 'gradient_checkpointing_disable'):
self.model.gradient_checkpointing_disable()
logger.info("Gradient checkpointing disabled for BGE-M3")
def _detect_device(self) -> torch.device:
"""Enhanced device detection with proper error handling"""
# Check Ascend NPU first
try:
import torch_npu
if hasattr(torch, 'npu') and torch.npu.is_available(): # type: ignore[attr-defined]
logger.info("Using Ascend NPU")
return torch.device("npu")
except ImportError:
pass
except Exception as e:
logger.debug(f"NPU detection failed: {e}")
# Check CUDA
if torch.cuda.is_available():
logger.info("Using CUDA GPU")
return torch.device("cuda")
# Check MPS (Apple Silicon)
try:
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # type: ignore[attr-defined]
logger.info("Using Apple MPS")
return torch.device("mps")
except Exception as e:
logger.debug(f"MPS detection failed: {e}")
# Default to CPU
logger.info("Using CPU")
return torch.device("cpu")
def _get_tokenizer(self):
"""Lazy load tokenizer with error handling"""
if self._tokenizer is None:
try:
self._tokenizer = AutoTokenizer.from_pretrained(
self.model_name_or_path,
trust_remote_code=True
)
logger.info("Tokenizer loaded successfully")
except Exception as e:
logger.error(f"Failed to load tokenizer: {e}")
raise RuntimeError(f"Tokenizer loading failed: {e}")
return self._tokenizer
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
return_dense: bool = True,
return_sparse: bool = False,
return_colbert: bool = False,
return_dict: bool = True
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass for BGE-M3 with dense, sparse (SPLADE+), and ColBERT heads.
Args:
input_ids: Token IDs [batch_size, seq_length]
attention_mask: Attention mask [batch_size, seq_length]
token_type_ids: Token type IDs (optional)
return_dense: Return dense embeddings
return_sparse: Return sparse embeddings
return_colbert: Return ColBERT embeddings
return_dict: Return dictionary of outputs
Returns:
Embeddings or dictionary of embeddings
"""
# Validate inputs
if input_ids is None or attention_mask is None:
raise ValueError("input_ids and attention_mask are required")
if input_ids.shape != attention_mask.shape:
raise ValueError("input_ids and attention_mask must have the same shape")
# Determine what to return based on model configuration and parameters
return_dense = return_dense and self.use_dense
return_sparse = return_sparse and self.use_sparse
return_colbert = return_colbert and self.use_colbert
if not (return_dense or return_sparse or return_colbert):
logger.warning("No output head selected in forward. At least one of dense, sparse, or colbert must be True.")
return_dense = True # Default to dense if nothing selected
try:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
last_hidden_state = outputs.last_hidden_state # [batch_size, seq_length, hidden_size]
except Exception as e:
logger.error(f"Model forward pass failed: {e}")
raise RuntimeError(f"Forward pass failed: {e}")
results = {}
# Dense embeddings (CLS or mean pooling)
if return_dense:
if self.pooling_method == "cls" or self.sentence_pooling_method == "cls":
dense_vecs = last_hidden_state[:, 0]
elif self.pooling_method == "mean" or self.sentence_pooling_method == "mean":
expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.shape)
masked_embeddings = last_hidden_state * expanded_mask
sum_embeddings = torch.sum(masked_embeddings, dim=1)
sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
dense_vecs = sum_embeddings / sum_mask
else:
raise ValueError(f"Unknown pooling method: {self.pooling_method}")
dense_vecs = self.dense_linear(dense_vecs)
# Ensure embeddings are float type
dense_vecs = dense_vecs.float()
if self.normalize_embeddings:
dense_vecs = F.normalize(dense_vecs, p=2, dim=-1)
results['dense'] = dense_vecs
# SPLADE+ sparse head
if return_sparse:
# SPLADE+ uses log1p(ReLU(.))
sparse_logits = self.sparse_linear(last_hidden_state).squeeze(-1) # [batch, seq_len]
sparse_activations = torch.log1p(F.relu(sparse_logits))
# Mask out padding tokens
sparse_activations = sparse_activations * attention_mask.float()
# Optionally normalize sparse activations
if self.normalize_embeddings:
sparse_activations = F.normalize(sparse_activations, p=2, dim=-1)
results['sparse'] = sparse_activations
# FLOPS regularization (for loss, not forward, but add as attribute)
self.splade_flops = torch.sum(sparse_activations)
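            # This scalar can be added (scaled) to the training loss as a FLOPS-style
            # regularizer to encourage sparser activations; it is not used in forward itself.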
# ColBERT head with late interaction (maxsim)
if return_colbert:
colbert_vecs = self.colbert_linear(last_hidden_state) # [batch, seq_len, colbert_dim]
# Mask out padding tokens
expanded_mask = attention_mask.unsqueeze(-1).expand(colbert_vecs.shape)
colbert_vecs = colbert_vecs * expanded_mask
if self.normalize_embeddings:
colbert_vecs = F.normalize(colbert_vecs, p=2, dim=-1)
results['colbert'] = colbert_vecs
if return_dict:
return results
# If only one head requested, return that
if return_dense and len(results) == 1:
return results['dense']
if return_sparse and len(results) == 1:
return results['sparse']
if return_colbert and len(results) == 1:
return results['colbert']
return results
def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 32,
max_length: int = 512,
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
show_progress_bar: bool = False,
normalize_embeddings: bool = True,
return_dense: bool = True,
return_sparse: bool = False,
return_colbert: bool = False,
device: Optional[str] = None
) -> Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]:
"""
Complete encode implementation with comprehensive error handling
Args:
sentences: Input sentences
batch_size: Batch size for encoding
max_length: Maximum sequence length
convert_to_numpy: Convert to numpy arrays
convert_to_tensor: Keep as tensors
show_progress_bar: Show progress bar
normalize_embeddings: Normalize embeddings
return_dense/sparse/colbert: Which embeddings to return
device: Device to use
Returns:
Embeddings
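        Example (illustrative):
            dense = model.encode(["what is BGE-M3?"])                               # np.ndarray [1, hidden]
            multi = model.encode(["query"], return_dense=True, return_sparse=True)  # dict with 'dense' and 'sparse'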
"""
# Input validation
if isinstance(sentences, str):
sentences = [sentences]
if not sentences:
raise ValueError("Empty sentences list provided")
if batch_size <= 0:
raise ValueError("batch_size must be positive")
if max_length <= 0:
raise ValueError("max_length must be positive")
# Get tokenizer
try:
tokenizer = self._get_tokenizer()
except Exception as e:
raise RuntimeError(f"Failed to load tokenizer: {e}")
# Determine target device
target_device = device if device else self.device
# Initialize storage for embeddings
all_embeddings = {
'dense': [] if return_dense else None,
'sparse': [] if return_sparse else None,
'colbert': [] if return_colbert else None
}
# Set model to evaluation mode
self.eval()
try:
# Import tqdm if show_progress_bar is True
if show_progress_bar:
try:
from tqdm import tqdm
iterator = tqdm(
range(0, len(sentences), batch_size),
desc="Encoding",
total=(len(sentences) + batch_size - 1) // batch_size
)
except ImportError:
logger.warning("tqdm not available, disabling progress bar")
iterator = range(0, len(sentences), batch_size)
else:
iterator = range(0, len(sentences), batch_size)
# Process in batches
for i in iterator:
batch_sentences = sentences[i:i + batch_size]
# Tokenize with proper error handling
try:
encoded = tokenizer(
batch_sentences,
padding=True,
truncation=True,
max_length=max_length,
return_tensors='pt'
)
except Exception as e:
raise RuntimeError(f"Tokenization failed for batch {i//batch_size + 1}: {e}")
# Move to device
try:
encoded = {k: v.to(target_device) for k, v in encoded.items()}
except Exception as e:
raise RuntimeError(f"Failed to move batch to device {target_device}: {e}")
# Get embeddings
with torch.no_grad():
try:
outputs = self.forward(
**encoded,
return_dense=return_dense,
return_sparse=return_sparse,
return_colbert=return_colbert,
return_dict=True
)
except Exception as e:
raise RuntimeError(f"Forward pass failed for batch {i//batch_size + 1}: {e}")
# Collect embeddings
for key in ['dense', 'sparse', 'colbert']:
if all_embeddings[key] is not None and key in outputs:
all_embeddings[key].append(outputs[key].cpu()) # type: ignore[arg-type]
# Concatenate and convert results
final_embeddings = {}
for key in ['dense', 'sparse', 'colbert']:
if all_embeddings[key] is not None and len(all_embeddings[key]) > 0: # type: ignore[arg-type]
try:
embeddings = torch.cat(all_embeddings[key], dim=0) # type: ignore[arg-type]
if normalize_embeddings and key == 'dense':
embeddings = F.normalize(embeddings, p=2, dim=-1)
                        if convert_to_numpy:
                            embeddings = embeddings.numpy()
                        # else: keep torch tensors (covers convert_to_tensor=True and the default)
final_embeddings[key] = embeddings
except Exception as e:
logger.error(f"Failed to process {key} embeddings: {e}")
raise
# Return format
if len(final_embeddings) == 1:
return list(final_embeddings.values())[0]
if not final_embeddings:
raise RuntimeError("No embeddings were generated")
return final_embeddings
except Exception as e:
logger.error(f"Encoding failed: {e}")
raise
def save_pretrained(self, save_path: str):
"""Save model to directory with error handling, including custom heads and config."""
try:
import os
import json
os.makedirs(save_path, exist_ok=True)
# Save the base model
self.model.save_pretrained(save_path)
# Save additional configuration (including [m3] section)
from utils.config_loader import get_config
config_dict = get_config().get_section('m3')
config = {
'use_dense': self.use_dense,
'use_sparse': self.use_sparse,
'use_colbert': self.use_colbert,
'pooling_method': self.pooling_method,
'normalize_embeddings': self.normalize_embeddings,
'colbert_dim': self.colbert_dim,
'sparse_top_k': self.sparse_top_k,
'model_name_or_path': self.model_name_or_path,
'm3_config': config_dict
}
with open(f"{save_path}/m3_config.json", 'w') as f:
json.dump(config, f, indent=2)
# Save custom heads if needed
if self.use_sparse:
torch.save(self.sparse_linear.state_dict(), os.path.join(save_path, 'sparse_linear.pt'))
if self.use_colbert:
torch.save(self.colbert_linear.state_dict(), os.path.join(save_path, 'colbert_linear.pt'))
# Save tokenizer if available
if self._tokenizer is not None:
self._tokenizer.save_pretrained(save_path)
logger.info(f"Model and config saved to {save_path}")
except Exception as e:
logger.error(f"Failed to save model: {e}")
raise
@classmethod
def from_pretrained(cls, model_path: str, **kwargs):
"""Load model from directory, restoring custom heads and config."""
import os
import json
try:
# Load config
config_path = os.path.join(model_path, "m3_config.json")
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
else:
config = {}
# Instantiate model
model = cls(
model_name_or_path=config.get('model_name_or_path', model_path),
use_dense=config.get('use_dense', True),
use_sparse=config.get('use_sparse', True),
use_colbert=config.get('use_colbert', True),
pooling_method=config.get('pooling_method', 'cls'),
normalize_embeddings=config.get('normalize_embeddings', True),
colbert_dim=config.get('colbert_dim', -1),
sparse_top_k=config.get('sparse_top_k', 100),
**kwargs
)
            # Load base model weights (then move them to the device chosen in __init__)
            model.model = model.model.from_pretrained(model_path).to(model.device)
# Load custom heads if present
if model.use_sparse:
sparse_path = os.path.join(model_path, 'sparse_linear.pt')
if os.path.exists(sparse_path):
model.sparse_linear.load_state_dict(torch.load(sparse_path, map_location='cpu'))
if model.use_colbert:
colbert_path = os.path.join(model_path, 'colbert_linear.pt')
if os.path.exists(colbert_path):
model.colbert_linear.load_state_dict(torch.load(colbert_path, map_location='cpu'))
logger.info(f"Loaded BGE-M3 model from {model_path}")
return model
except Exception as e:
logger.error(f"Failed to load model from {model_path}: {e}")
raise
def compute_similarity(
self,
queries: Union[torch.Tensor, Dict[str, torch.Tensor]],
passages: Union[torch.Tensor, Dict[str, torch.Tensor]],
mode: str = "dense"
) -> torch.Tensor:
"""
Compute similarity scores between queries and passages with error handling
Args:
queries: Query embeddings
passages: Passage embeddings
mode: Similarity mode ('dense', 'sparse', 'colbert')
Returns:
Similarity scores [num_queries, num_passages]
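        Example (illustrative):
            q = model.encode(["query"], convert_to_numpy=False, convert_to_tensor=True)
            p = model.encode(["passage one", "passage two"], convert_to_numpy=False, convert_to_tensor=True)
            scores = model.compute_similarity(q, p, mode="dense")  # shape [1, 2]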
"""
try:
# Extract embeddings based on mode
if isinstance(queries, dict):
if mode not in queries:
raise ValueError(f"Mode '{mode}' not found in query embeddings")
query_emb = queries[mode]
else:
query_emb = queries
if isinstance(passages, dict):
if mode not in passages:
raise ValueError(f"Mode '{mode}' not found in passage embeddings")
passage_emb = passages[mode]
else:
passage_emb = passages
# Validate tensor shapes
if query_emb.dim() < 2 or passage_emb.dim() < 2:
raise ValueError("Embeddings must be at least 2-dimensional")
if mode == "dense":
# Simple dot product for dense
return torch.matmul(query_emb, passage_emb.t())
elif mode == "sparse":
# Dot product for sparse (simplified)
return torch.matmul(query_emb, passage_emb.t())
elif mode == "colbert":
# MaxSim for ColBERT
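                # Late-interaction scoring: score(q, p) = sum_i max_j (q_i · p_j),
                # i.e. each query token keeps its best-matching passage token.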
if query_emb.dim() != 3 or passage_emb.dim() != 3:
raise ValueError("ColBERT embeddings must be 3-dimensional [batch, seq_len, dim]")
scores = []
for q in query_emb:
# q: [query_len, dim]
# passage_emb: [num_passages, passage_len, dim]
passage_scores = []
for p in passage_emb:
# Compute similarity matrix
sim_matrix = torch.matmul(q, p.t()) # [query_len, passage_len]
# Max over passage tokens for each query token
max_sim_scores = sim_matrix.max(dim=1)[0] # [query_len]
# Sum over query tokens
score = max_sim_scores.sum()
passage_scores.append(score)
scores.append(torch.stack(passage_scores))
return torch.stack(scores)
else:
raise ValueError(f"Unknown similarity mode: {mode}")
except Exception as e:
logger.error(f"Similarity computation failed: {e}")
raise
    def get_sentence_embedding_dimension(self) -> int:
        """Get the dimension of sentence embeddings"""
        return (self.config.get('hidden_size', 768) if isinstance(self.config, dict)
                else getattr(self.config, 'hidden_size', 768))
    def get_word_embedding_dimension(self) -> int:
        """Get the dimension of word embeddings"""
        return (self.config.get('hidden_size', 768) if isinstance(self.config, dict)
                else getattr(self.config, 'hidden_size', 768))