init commit
This commit is contained in:
768
models/bge_m3.py
Normal file
768
models/bge_m3.py
Normal file
@@ -0,0 +1,768 @@
|
||||
"""
|
||||
BGE-M3 model wrapper for multi-functionality embedding
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
import numpy as np
|
||||
import logging
|
||||
from utils.config_loader import get_config
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(model_name_or_path, cache_dir=None):
|
||||
"""
|
||||
Load model and tokenizer with comprehensive error handling and fallback strategies
|
||||
|
||||
Args:
|
||||
model_name_or_path: Model identifier or local path
|
||||
cache_dir: Cache directory for downloaded models
|
||||
|
||||
Returns:
|
||||
Tuple of (model, tokenizer, config)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails after all attempts
|
||||
"""
|
||||
config = get_config()
|
||||
model_source = config.get('model_paths', 'source', 'huggingface')
|
||||
|
||||
# Track loading attempts for debugging
|
||||
attempts = []
|
||||
|
||||
# Try primary loading method
|
||||
try:
|
||||
if model_source == 'huggingface':
|
||||
return _load_from_huggingface(model_name_or_path, cache_dir)
|
||||
elif model_source == 'modelscope':
|
||||
return _load_from_modelscope(model_name_or_path, cache_dir)
|
||||
else:
|
||||
raise ValueError(f"Unknown model source: {model_source}")
|
||||
except Exception as e:
|
||||
attempts.append(f"{model_source}: {str(e)}")
|
||||
logger.warning(f"Primary loading from {model_source} failed: {e}")
|
||||
|
||||
# Fallback strategies
|
||||
fallback_sources = ['huggingface'] if model_source == 'modelscope' else ['modelscope']
|
||||
|
||||
for fallback in fallback_sources:
|
||||
try:
|
||||
logger.info(f"Attempting fallback loading from {fallback}")
|
||||
if fallback == 'huggingface':
|
||||
return _load_from_huggingface(model_name_or_path, cache_dir)
|
||||
elif fallback == 'modelscope':
|
||||
return _load_from_modelscope(model_name_or_path, cache_dir)
|
||||
except Exception as e:
|
||||
attempts.append(f"{fallback}: {str(e)}")
|
||||
logger.warning(f"Fallback loading from {fallback} failed: {e}")
|
||||
|
||||
# Try loading from local path if it exists
|
||||
if os.path.exists(model_name_or_path):
|
||||
try:
|
||||
logger.info(f"Attempting to load from local path: {model_name_or_path}")
|
||||
return _load_from_local(model_name_or_path, cache_dir)
|
||||
except Exception as e:
|
||||
attempts.append(f"local: {str(e)}")
|
||||
logger.warning(f"Local loading failed: {e}")
|
||||
|
||||
# All attempts failed
|
||||
error_msg = f"Failed to load model from any source. Attempts:\n" + "\n".join(attempts)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def _load_from_huggingface(model_name_or_path, cache_dir=None):
|
||||
"""Load model from HuggingFace with error handling"""
|
||||
try:
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
||||
|
||||
# Load config first to validate model
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Load model
|
||||
model = AutoModel.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=model_config,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
logger.info(f"Successfully loaded model from HuggingFace: {model_name_or_path}")
|
||||
return model, tokenizer, model_config
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Transformers library not available: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"HuggingFace loading error: {e}")
|
||||
|
||||
|
||||
def _load_from_modelscope(model_name_or_path, cache_dir=None):
|
||||
"""Load model from ModelScope with error handling"""
|
||||
try:
|
||||
from modelscope.models import Model as MSModel
|
||||
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
|
||||
|
||||
# Download model if needed
|
||||
local_path = ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
|
||||
|
||||
# Load model
|
||||
model = MSModel.from_pretrained(local_path)
|
||||
|
||||
# Try to load tokenizer
|
||||
tokenizer = None
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load tokenizer from ModelScope: {e}")
|
||||
|
||||
# Try to get config
|
||||
config = None
|
||||
config_path = os.path.join(local_path, 'config.json')
|
||||
if os.path.exists(config_path):
|
||||
import json
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
logger.info(f"Successfully loaded model from ModelScope: {model_name_or_path}")
|
||||
return model, tokenizer, config
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(f"ModelScope library not available: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"ModelScope loading error: {e}")
|
||||
|
||||
|
||||
def _load_from_local(model_path, cache_dir=None):
|
||||
"""Load model from local directory with error handling"""
|
||||
try:
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig
|
||||
|
||||
if not os.path.isdir(model_path):
|
||||
raise ValueError(f"Local path is not a directory: {model_path}")
|
||||
|
||||
# Check for required files
|
||||
required_files = ['config.json']
|
||||
missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
|
||||
if missing_files:
|
||||
raise FileNotFoundError(f"Missing required files: {missing_files}")
|
||||
|
||||
# Load components
|
||||
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = AutoModel.from_pretrained(model_path, config=model_config, trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
logger.info(f"Successfully loaded model from local path: {model_path}")
|
||||
return model, tokenizer, model_config
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Local loading error: {e}")
|
||||
|
||||
|
||||
class BGEM3Model(nn.Module):
|
||||
"""
|
||||
BGE-M3 model wrapper supporting dense, sparse, and multi-vector retrieval
|
||||
|
||||
Notes:
|
||||
- All parameters (model_name_or_path, use_dense, use_sparse, use_colbert, pooling_method, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str = "BAAI/bge-m3",
|
||||
use_dense: bool = True,
|
||||
use_sparse: bool = True,
|
||||
use_colbert: bool = True,
|
||||
pooling_method: str = "cls",
|
||||
normalize_embeddings: bool = True,
|
||||
sentence_pooling_method: str = "cls",
|
||||
colbert_dim: int = -1,
|
||||
sparse_top_k: int = 100,
|
||||
cache_dir: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
gradient_checkpointing: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.use_dense = use_dense
|
||||
self.use_sparse = use_sparse
|
||||
self.use_colbert = use_colbert
|
||||
self.pooling_method = pooling_method
|
||||
self.normalize_embeddings = normalize_embeddings
|
||||
self.colbert_dim = colbert_dim
|
||||
self.sparse_top_k = sparse_top_k
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
# Initialize model with comprehensive error handling
|
||||
try:
|
||||
logger.info(f"Loading BGE-M3 model from {model_name_or_path} using source from config.toml...")
|
||||
self.model, self.tokenizer, self.config = load_model_and_tokenizer(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
logger.info("BGE-M3 model loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load BGE-M3 model: {e}")
|
||||
raise RuntimeError(f"Model initialization failed: {e}")
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if self.gradient_checkpointing:
|
||||
self.enable_gradient_checkpointing()
|
||||
|
||||
# Ensure we have a valid config
|
||||
if self.config is None:
|
||||
logger.warning("Model config is None, creating a basic config")
|
||||
# Create a basic config object
|
||||
class BasicConfig:
|
||||
def __init__(self):
|
||||
self.hidden_size = 768 # Default BERT hidden size
|
||||
self.config = BasicConfig()
|
||||
|
||||
# Initialize heads for different functionalities
|
||||
hidden_size = getattr(self.config, 'hidden_size', 768)
|
||||
|
||||
if self.use_dense:
|
||||
# Dense retrieval doesn't need additional layers (uses CLS token)
|
||||
self.dense_linear = nn.Identity()
|
||||
|
||||
if self.use_sparse:
|
||||
# Sparse retrieval head (SPLADE-style)
|
||||
self.sparse_linear = nn.Linear(hidden_size, 1)
|
||||
self.sparse_activation = nn.ReLU()
|
||||
|
||||
if self.use_colbert:
|
||||
# ColBERT projection layer
|
||||
colbert_dim = colbert_dim if colbert_dim > 0 else hidden_size
|
||||
self.colbert_linear = nn.Linear(hidden_size, colbert_dim)
|
||||
|
||||
# Set device with enhanced detection
|
||||
if device:
|
||||
self.device = torch.device(device)
|
||||
else:
|
||||
self.device = self._detect_device()
|
||||
|
||||
try:
|
||||
self.to(self.device)
|
||||
logger.info(f"BGE-M3 model moved to device: {self.device}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to move model to {self.device}: {e}, using CPU")
|
||||
self.device = torch.device('cpu')
|
||||
self.to(self.device)
|
||||
|
||||
# Cache for tokenizer
|
||||
self._tokenizer = None
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
"""Enable gradient checkpointing for memory efficiency"""
|
||||
if hasattr(self.model, 'gradient_checkpointing_enable'):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
logger.info("Gradient checkpointing enabled for BGE-M3")
|
||||
else:
|
||||
logger.warning("Model does not support gradient checkpointing")
|
||||
|
||||
def disable_gradient_checkpointing(self):
|
||||
"""Disable gradient checkpointing"""
|
||||
if hasattr(self.model, 'gradient_checkpointing_disable'):
|
||||
self.model.gradient_checkpointing_disable()
|
||||
logger.info("Gradient checkpointing disabled for BGE-M3")
|
||||
|
||||
def _detect_device(self) -> torch.device:
|
||||
"""Enhanced device detection with proper error handling"""
|
||||
# Check Ascend NPU first
|
||||
try:
|
||||
import torch_npu
|
||||
if hasattr(torch, 'npu') and torch.npu.is_available(): # type: ignore[attr-defined]
|
||||
logger.info("Using Ascend NPU")
|
||||
return torch.device("npu")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"NPU detection failed: {e}")
|
||||
pass
|
||||
|
||||
# Check CUDA
|
||||
if torch.cuda.is_available():
|
||||
logger.info("Using CUDA GPU")
|
||||
return torch.device("cuda")
|
||||
|
||||
# Check MPS (Apple Silicon)
|
||||
try:
|
||||
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # type: ignore[attr-defined]
|
||||
logger.info("Using Apple MPS")
|
||||
return torch.device("mps")
|
||||
except Exception as e:
|
||||
logger.debug(f"MPS detection failed: {e}")
|
||||
pass
|
||||
|
||||
# Default to CPU
|
||||
logger.info("Using CPU")
|
||||
return torch.device("cpu")
|
||||
|
||||
def _get_tokenizer(self):
|
||||
"""Lazy load tokenizer with error handling"""
|
||||
if self._tokenizer is None:
|
||||
try:
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name_or_path,
|
||||
trust_remote_code=True
|
||||
)
|
||||
logger.info("Tokenizer loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load tokenizer: {e}")
|
||||
raise RuntimeError(f"Tokenizer loading failed: {e}")
|
||||
return self._tokenizer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
return_dense: bool = True,
|
||||
return_sparse: bool = False,
|
||||
return_colbert: bool = False,
|
||||
return_dict: bool = True
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Forward pass for BGE-M3 with dense, sparse (SPLADE+), and ColBERT heads.
|
||||
Args:
|
||||
input_ids: Token IDs [batch_size, seq_length]
|
||||
attention_mask: Attention mask [batch_size, seq_length]
|
||||
token_type_ids: Token type IDs (optional)
|
||||
return_dense: Return dense embeddings
|
||||
return_sparse: Return sparse embeddings
|
||||
return_colbert: Return ColBERT embeddings
|
||||
return_dict: Return dictionary of outputs
|
||||
Returns:
|
||||
Embeddings or dictionary of embeddings
|
||||
"""
|
||||
# Validate inputs
|
||||
if input_ids is None or attention_mask is None:
|
||||
raise ValueError("input_ids and attention_mask are required")
|
||||
if input_ids.shape != attention_mask.shape:
|
||||
raise ValueError("input_ids and attention_mask must have the same shape")
|
||||
|
||||
# Determine what to return based on model configuration and parameters
|
||||
return_dense = return_dense and self.use_dense
|
||||
return_sparse = return_sparse and self.use_sparse
|
||||
return_colbert = return_colbert and self.use_colbert
|
||||
|
||||
if not (return_dense or return_sparse or return_colbert):
|
||||
logger.warning("No output head selected in forward. At least one of dense, sparse, or colbert must be True.")
|
||||
return_dense = True # Default to dense if nothing selected
|
||||
|
||||
try:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
return_dict=True
|
||||
)
|
||||
last_hidden_state = outputs.last_hidden_state # [batch_size, seq_length, hidden_size]
|
||||
except Exception as e:
|
||||
logger.error(f"Model forward pass failed: {e}")
|
||||
raise RuntimeError(f"Forward pass failed: {e}")
|
||||
|
||||
results = {}
|
||||
|
||||
# Dense embeddings (CLS or mean pooling)
|
||||
if return_dense:
|
||||
if self.pooling_method == "cls" or self.sentence_pooling_method == "cls":
|
||||
dense_vecs = last_hidden_state[:, 0]
|
||||
elif self.pooling_method == "mean" or self.sentence_pooling_method == "mean":
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.shape)
|
||||
masked_embeddings = last_hidden_state * expanded_mask
|
||||
sum_embeddings = torch.sum(masked_embeddings, dim=1)
|
||||
sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
|
||||
dense_vecs = sum_embeddings / sum_mask
|
||||
else:
|
||||
raise ValueError(f"Unknown pooling method: {self.pooling_method}")
|
||||
|
||||
dense_vecs = self.dense_linear(dense_vecs)
|
||||
# Ensure embeddings are float type
|
||||
dense_vecs = dense_vecs.float()
|
||||
if self.normalize_embeddings:
|
||||
dense_vecs = F.normalize(dense_vecs, p=2, dim=-1)
|
||||
results['dense'] = dense_vecs
|
||||
|
||||
# SPLADE+ sparse head
|
||||
if return_sparse:
|
||||
# SPLADE+ uses log1p(ReLU(.))
|
||||
sparse_logits = self.sparse_linear(last_hidden_state).squeeze(-1) # [batch, seq_len]
|
||||
sparse_activations = torch.log1p(F.relu(sparse_logits))
|
||||
|
||||
# Mask out padding tokens
|
||||
sparse_activations = sparse_activations * attention_mask.float()
|
||||
|
||||
# Optionally normalize sparse activations
|
||||
if self.normalize_embeddings:
|
||||
sparse_activations = F.normalize(sparse_activations, p=2, dim=-1)
|
||||
|
||||
results['sparse'] = sparse_activations
|
||||
|
||||
# FLOPS regularization (for loss, not forward, but add as attribute)
|
||||
self.splade_flops = torch.sum(sparse_activations)
|
||||
|
||||
# ColBERT head with late interaction (maxsim)
|
||||
if return_colbert:
|
||||
colbert_vecs = self.colbert_linear(last_hidden_state) # [batch, seq_len, colbert_dim]
|
||||
|
||||
# Mask out padding tokens
|
||||
expanded_mask = attention_mask.unsqueeze(-1).expand(colbert_vecs.shape)
|
||||
colbert_vecs = colbert_vecs * expanded_mask
|
||||
|
||||
if self.normalize_embeddings:
|
||||
colbert_vecs = F.normalize(colbert_vecs, p=2, dim=-1)
|
||||
|
||||
results['colbert'] = colbert_vecs
|
||||
|
||||
if return_dict:
|
||||
return results
|
||||
|
||||
# If only one head requested, return that
|
||||
if return_dense and len(results) == 1:
|
||||
return results['dense']
|
||||
if return_sparse and len(results) == 1:
|
||||
return results['sparse']
|
||||
if return_colbert and len(results) == 1:
|
||||
return results['colbert']
|
||||
|
||||
return results
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sentences: Union[str, List[str]],
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512,
|
||||
convert_to_numpy: bool = True,
|
||||
convert_to_tensor: bool = False,
|
||||
show_progress_bar: bool = False,
|
||||
normalize_embeddings: bool = True,
|
||||
return_dense: bool = True,
|
||||
return_sparse: bool = False,
|
||||
return_colbert: bool = False,
|
||||
device: Optional[str] = None
|
||||
) -> Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]:
|
||||
"""
|
||||
Complete encode implementation with comprehensive error handling
|
||||
|
||||
Args:
|
||||
sentences: Input sentences
|
||||
batch_size: Batch size for encoding
|
||||
max_length: Maximum sequence length
|
||||
convert_to_numpy: Convert to numpy arrays
|
||||
convert_to_tensor: Keep as tensors
|
||||
show_progress_bar: Show progress bar
|
||||
normalize_embeddings: Normalize embeddings
|
||||
return_dense/sparse/colbert: Which embeddings to return
|
||||
device: Device to use
|
||||
|
||||
Returns:
|
||||
Embeddings
|
||||
"""
|
||||
# Input validation
|
||||
if isinstance(sentences, str):
|
||||
sentences = [sentences]
|
||||
|
||||
if not sentences:
|
||||
raise ValueError("Empty sentences list provided")
|
||||
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size must be positive")
|
||||
|
||||
if max_length <= 0:
|
||||
raise ValueError("max_length must be positive")
|
||||
|
||||
# Get tokenizer
|
||||
try:
|
||||
tokenizer = self._get_tokenizer()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
||||
|
||||
# Determine target device
|
||||
target_device = device if device else self.device
|
||||
|
||||
# Initialize storage for embeddings
|
||||
all_embeddings = {
|
||||
'dense': [] if return_dense else None,
|
||||
'sparse': [] if return_sparse else None,
|
||||
'colbert': [] if return_colbert else None
|
||||
}
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.eval()
|
||||
|
||||
try:
|
||||
# Import tqdm if show_progress_bar is True
|
||||
if show_progress_bar:
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
iterator = tqdm(
|
||||
range(0, len(sentences), batch_size),
|
||||
desc="Encoding",
|
||||
total=(len(sentences) + batch_size - 1) // batch_size
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("tqdm not available, disabling progress bar")
|
||||
iterator = range(0, len(sentences), batch_size)
|
||||
else:
|
||||
iterator = range(0, len(sentences), batch_size)
|
||||
|
||||
# Process in batches
|
||||
for i in iterator:
|
||||
batch_sentences = sentences[i:i + batch_size]
|
||||
|
||||
# Tokenize with proper error handling
|
||||
try:
|
||||
encoded = tokenizer(
|
||||
batch_sentences,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Tokenization failed for batch {i//batch_size + 1}: {e}")
|
||||
|
||||
# Move to device
|
||||
try:
|
||||
encoded = {k: v.to(target_device) for k, v in encoded.items()}
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to move batch to device {target_device}: {e}")
|
||||
|
||||
# Get embeddings
|
||||
with torch.no_grad():
|
||||
try:
|
||||
outputs = self.forward(
|
||||
**encoded,
|
||||
return_dense=return_dense,
|
||||
return_sparse=return_sparse,
|
||||
return_colbert=return_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Forward pass failed for batch {i//batch_size + 1}: {e}")
|
||||
|
||||
# Collect embeddings
|
||||
for key in ['dense', 'sparse', 'colbert']:
|
||||
if all_embeddings[key] is not None and key in outputs:
|
||||
all_embeddings[key].append(outputs[key].cpu()) # type: ignore[arg-type]
|
||||
|
||||
# Concatenate and convert results
|
||||
final_embeddings = {}
|
||||
for key in ['dense', 'sparse', 'colbert']:
|
||||
if all_embeddings[key] is not None and len(all_embeddings[key]) > 0: # type: ignore[arg-type]
|
||||
try:
|
||||
embeddings = torch.cat(all_embeddings[key], dim=0) # type: ignore[arg-type]
|
||||
|
||||
if normalize_embeddings and key == 'dense':
|
||||
embeddings = F.normalize(embeddings, p=2, dim=-1)
|
||||
|
||||
if convert_to_numpy:
|
||||
embeddings = embeddings.numpy()
|
||||
elif not convert_to_tensor:
|
||||
embeddings = embeddings
|
||||
|
||||
final_embeddings[key] = embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process {key} embeddings: {e}")
|
||||
raise
|
||||
|
||||
# Return format
|
||||
if len(final_embeddings) == 1:
|
||||
return list(final_embeddings.values())[0]
|
||||
|
||||
if not final_embeddings:
|
||||
raise RuntimeError("No embeddings were generated")
|
||||
|
||||
return final_embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Encoding failed: {e}")
|
||||
raise
|
||||
|
||||
def save_pretrained(self, save_path: str):
|
||||
"""Save model to directory with error handling, including custom heads and config."""
|
||||
try:
|
||||
import os
|
||||
import json
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save the base model
|
||||
self.model.save_pretrained(save_path)
|
||||
|
||||
# Save additional configuration (including [m3] section)
|
||||
from utils.config_loader import get_config
|
||||
config_dict = get_config().get_section('m3')
|
||||
config = {
|
||||
'use_dense': self.use_dense,
|
||||
'use_sparse': self.use_sparse,
|
||||
'use_colbert': self.use_colbert,
|
||||
'pooling_method': self.pooling_method,
|
||||
'normalize_embeddings': self.normalize_embeddings,
|
||||
'colbert_dim': self.colbert_dim,
|
||||
'sparse_top_k': self.sparse_top_k,
|
||||
'model_name_or_path': self.model_name_or_path,
|
||||
'm3_config': config_dict
|
||||
}
|
||||
with open(f"{save_path}/m3_config.json", 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save custom heads if needed
|
||||
if self.use_sparse:
|
||||
torch.save(self.sparse_linear.state_dict(), os.path.join(save_path, 'sparse_linear.pt'))
|
||||
if self.use_colbert:
|
||||
torch.save(self.colbert_linear.state_dict(), os.path.join(save_path, 'colbert_linear.pt'))
|
||||
|
||||
# Save tokenizer if available
|
||||
if self._tokenizer is not None:
|
||||
self._tokenizer.save_pretrained(save_path)
|
||||
|
||||
logger.info(f"Model and config saved to {save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save model: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, **kwargs):
|
||||
"""Load model from directory, restoring custom heads and config."""
|
||||
import os
|
||||
import json
|
||||
try:
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "m3_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
# Instantiate model
|
||||
model = cls(
|
||||
model_name_or_path=config.get('model_name_or_path', model_path),
|
||||
use_dense=config.get('use_dense', True),
|
||||
use_sparse=config.get('use_sparse', True),
|
||||
use_colbert=config.get('use_colbert', True),
|
||||
pooling_method=config.get('pooling_method', 'cls'),
|
||||
normalize_embeddings=config.get('normalize_embeddings', True),
|
||||
colbert_dim=config.get('colbert_dim', -1),
|
||||
sparse_top_k=config.get('sparse_top_k', 100),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Load base model weights
|
||||
model.model = model.model.from_pretrained(model_path)
|
||||
|
||||
# Load custom heads if present
|
||||
if model.use_sparse:
|
||||
sparse_path = os.path.join(model_path, 'sparse_linear.pt')
|
||||
if os.path.exists(sparse_path):
|
||||
model.sparse_linear.load_state_dict(torch.load(sparse_path, map_location='cpu'))
|
||||
if model.use_colbert:
|
||||
colbert_path = os.path.join(model_path, 'colbert_linear.pt')
|
||||
if os.path.exists(colbert_path):
|
||||
model.colbert_linear.load_state_dict(torch.load(colbert_path, map_location='cpu'))
|
||||
|
||||
logger.info(f"Loaded BGE-M3 model from {model_path}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model from {model_path}: {e}")
|
||||
raise
|
||||
|
||||
def compute_similarity(
|
||||
self,
|
||||
queries: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
passages: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
mode: str = "dense"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute similarity scores between queries and passages with error handling
|
||||
|
||||
Args:
|
||||
queries: Query embeddings
|
||||
passages: Passage embeddings
|
||||
mode: Similarity mode ('dense', 'sparse', 'colbert')
|
||||
|
||||
Returns:
|
||||
Similarity scores [num_queries, num_passages]
|
||||
"""
|
||||
try:
|
||||
# Extract embeddings based on mode
|
||||
if isinstance(queries, dict):
|
||||
if mode not in queries:
|
||||
raise ValueError(f"Mode '{mode}' not found in query embeddings")
|
||||
query_emb = queries[mode]
|
||||
else:
|
||||
query_emb = queries
|
||||
|
||||
if isinstance(passages, dict):
|
||||
if mode not in passages:
|
||||
raise ValueError(f"Mode '{mode}' not found in passage embeddings")
|
||||
passage_emb = passages[mode]
|
||||
else:
|
||||
passage_emb = passages
|
||||
|
||||
# Validate tensor shapes
|
||||
if query_emb.dim() < 2 or passage_emb.dim() < 2:
|
||||
raise ValueError("Embeddings must be at least 2-dimensional")
|
||||
|
||||
if mode == "dense":
|
||||
# Simple dot product for dense
|
||||
return torch.matmul(query_emb, passage_emb.t())
|
||||
|
||||
elif mode == "sparse":
|
||||
# Dot product for sparse (simplified)
|
||||
return torch.matmul(query_emb, passage_emb.t())
|
||||
|
||||
elif mode == "colbert":
|
||||
# MaxSim for ColBERT
|
||||
if query_emb.dim() != 3 or passage_emb.dim() != 3:
|
||||
raise ValueError("ColBERT embeddings must be 3-dimensional [batch, seq_len, dim]")
|
||||
|
||||
scores = []
|
||||
for q in query_emb:
|
||||
# q: [query_len, dim]
|
||||
# passage_emb: [num_passages, passage_len, dim]
|
||||
passage_scores = []
|
||||
for p in passage_emb:
|
||||
# Compute similarity matrix
|
||||
sim_matrix = torch.matmul(q, p.t()) # [query_len, passage_len]
|
||||
# Max over passage tokens for each query token
|
||||
max_sim_scores = sim_matrix.max(dim=1)[0] # [query_len]
|
||||
# Sum over query tokens
|
||||
score = max_sim_scores.sum()
|
||||
passage_scores.append(score)
|
||||
scores.append(torch.stack(passage_scores))
|
||||
return torch.stack(scores)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown similarity mode: {mode}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Similarity computation failed: {e}")
|
||||
raise
|
||||
|
||||
def get_sentence_embedding_dimension(self) -> int:
|
||||
"""Get the dimension of sentence embeddings"""
|
||||
return self.config.hidden_size # type: ignore[attr-defined]
|
||||
|
||||
def get_word_embedding_dimension(self) -> int:
|
||||
"""Get the dimension of word embeddings"""
|
||||
return self.config.hidden_size # type: ignore[attr-defined]
|
||||
Reference in New Issue
Block a user