""" BGE-M3 model wrapper for multi-functionality embedding """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel, AutoTokenizer, AutoConfig from typing import Dict, List, Optional, Union, Tuple import numpy as np import logging from utils.config_loader import get_config import os logger = logging.getLogger(__name__) def load_model_and_tokenizer(model_name_or_path, cache_dir=None): """ Load model and tokenizer with comprehensive error handling and fallback strategies Args: model_name_or_path: Model identifier or local path cache_dir: Cache directory for downloaded models Returns: Tuple of (model, tokenizer, config) Raises: RuntimeError: If model loading fails after all attempts """ config = get_config() model_source = config.get('model_paths', 'source', 'huggingface') # Track loading attempts for debugging attempts = [] # Try primary loading method try: if model_source == 'huggingface': return _load_from_huggingface(model_name_or_path, cache_dir) elif model_source == 'modelscope': return _load_from_modelscope(model_name_or_path, cache_dir) else: raise ValueError(f"Unknown model source: {model_source}") except Exception as e: attempts.append(f"{model_source}: {str(e)}") logger.warning(f"Primary loading from {model_source} failed: {e}") # Fallback strategies fallback_sources = ['huggingface'] if model_source == 'modelscope' else ['modelscope'] for fallback in fallback_sources: try: logger.info(f"Attempting fallback loading from {fallback}") if fallback == 'huggingface': return _load_from_huggingface(model_name_or_path, cache_dir) elif fallback == 'modelscope': return _load_from_modelscope(model_name_or_path, cache_dir) except Exception as e: attempts.append(f"{fallback}: {str(e)}") logger.warning(f"Fallback loading from {fallback} failed: {e}") # Try loading from local path if it exists if os.path.exists(model_name_or_path): try: logger.info(f"Attempting to load from local path: {model_name_or_path}") return _load_from_local(model_name_or_path, cache_dir) except Exception as e: attempts.append(f"local: {str(e)}") logger.warning(f"Local loading failed: {e}") # All attempts failed error_msg = f"Failed to load model from any source. 
Attempts:\n" + "\n".join(attempts) logger.error(error_msg) raise RuntimeError(error_msg) def _load_from_huggingface(model_name_or_path, cache_dir=None): """Load model from HuggingFace with error handling""" try: from transformers import AutoModel, AutoTokenizer, AutoConfig # Load config first to validate model model_config = AutoConfig.from_pretrained( model_name_or_path, cache_dir=cache_dir, trust_remote_code=True ) # Load model model = AutoModel.from_pretrained( model_name_or_path, config=model_config, cache_dir=cache_dir, trust_remote_code=True ) # Load tokenizer tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, cache_dir=cache_dir, trust_remote_code=True ) logger.info(f"Successfully loaded model from HuggingFace: {model_name_or_path}") return model, tokenizer, model_config except ImportError as e: raise ImportError(f"Transformers library not available: {e}") except Exception as e: raise RuntimeError(f"HuggingFace loading error: {e}") def _load_from_modelscope(model_name_or_path, cache_dir=None): """Load model from ModelScope with error handling""" try: from modelscope.models import Model as MSModel from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download # Download model if needed local_path = ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir) # Load model model = MSModel.from_pretrained(local_path) # Try to load tokenizer tokenizer = None try: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(local_path) except Exception as e: logger.warning(f"Could not load tokenizer from ModelScope: {e}") # Try to get config config = None config_path = os.path.join(local_path, 'config.json') if os.path.exists(config_path): import json with open(config_path, 'r') as f: config = json.load(f) logger.info(f"Successfully loaded model from ModelScope: {model_name_or_path}") return model, tokenizer, config except ImportError as e: raise ImportError(f"ModelScope library not available: {e}") except Exception as e: raise RuntimeError(f"ModelScope loading error: {e}") def _load_from_local(model_path, cache_dir=None): """Load model from local directory with error handling""" try: from transformers import AutoModel, AutoTokenizer, AutoConfig if not os.path.isdir(model_path): raise ValueError(f"Local path is not a directory: {model_path}") # Check for required files required_files = ['config.json'] missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))] if missing_files: raise FileNotFoundError(f"Missing required files: {missing_files}") # Load components model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) model = AutoModel.from_pretrained(model_path, config=model_config, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) logger.info(f"Successfully loaded model from local path: {model_path}") return model, tokenizer, model_config except Exception as e: raise RuntimeError(f"Local loading error: {e}") class BGEM3Model(nn.Module): """ BGE-M3 model wrapper supporting dense, sparse, and multi-vector retrieval Notes: - All parameters (model_name_or_path, use_dense, use_sparse, use_colbert, pooling_method, etc.) should be passed from config-driven training scripts. - This class does not load config directly; pass config values from your main script for consistency. 
""" def __init__( self, model_name_or_path: str = "BAAI/bge-m3", use_dense: bool = True, use_sparse: bool = True, use_colbert: bool = True, pooling_method: str = "cls", normalize_embeddings: bool = True, sentence_pooling_method: str = "cls", colbert_dim: int = -1, sparse_top_k: int = 100, cache_dir: Optional[str] = None, device: Optional[str] = None, gradient_checkpointing: bool = False ): super().__init__() self.model_name_or_path = model_name_or_path self.use_dense = use_dense self.use_sparse = use_sparse self.use_colbert = use_colbert self.pooling_method = pooling_method self.normalize_embeddings = normalize_embeddings self.colbert_dim = colbert_dim self.sparse_top_k = sparse_top_k self.gradient_checkpointing = gradient_checkpointing # Initialize model with comprehensive error handling try: logger.info(f"Loading BGE-M3 model from {model_name_or_path} using source from config.toml...") self.model, self.tokenizer, self.config = load_model_and_tokenizer( model_name_or_path, cache_dir=cache_dir ) logger.info("BGE-M3 model loaded successfully") except Exception as e: logger.error(f"Failed to load BGE-M3 model: {e}") raise RuntimeError(f"Model initialization failed: {e}") # Enable gradient checkpointing if requested if self.gradient_checkpointing: self.enable_gradient_checkpointing() # Ensure we have a valid config if self.config is None: logger.warning("Model config is None, creating a basic config") # Create a basic config object class BasicConfig: def __init__(self): self.hidden_size = 768 # Default BERT hidden size self.config = BasicConfig() # Initialize heads for different functionalities hidden_size = getattr(self.config, 'hidden_size', 768) if self.use_dense: # Dense retrieval doesn't need additional layers (uses CLS token) self.dense_linear = nn.Identity() if self.use_sparse: # Sparse retrieval head (SPLADE-style) self.sparse_linear = nn.Linear(hidden_size, 1) self.sparse_activation = nn.ReLU() if self.use_colbert: # ColBERT projection layer colbert_dim = colbert_dim if colbert_dim > 0 else hidden_size self.colbert_linear = nn.Linear(hidden_size, colbert_dim) # Set device with enhanced detection if device: self.device = torch.device(device) else: self.device = self._detect_device() try: self.to(self.device) logger.info(f"BGE-M3 model moved to device: {self.device}") except Exception as e: logger.warning(f"Failed to move model to {self.device}: {e}, using CPU") self.device = torch.device('cpu') self.to(self.device) # Cache for tokenizer self._tokenizer = None def enable_gradient_checkpointing(self): """Enable gradient checkpointing for memory efficiency""" if hasattr(self.model, 'gradient_checkpointing_enable'): self.model.gradient_checkpointing_enable() logger.info("Gradient checkpointing enabled for BGE-M3") else: logger.warning("Model does not support gradient checkpointing") def disable_gradient_checkpointing(self): """Disable gradient checkpointing""" if hasattr(self.model, 'gradient_checkpointing_disable'): self.model.gradient_checkpointing_disable() logger.info("Gradient checkpointing disabled for BGE-M3") def _detect_device(self) -> torch.device: """Enhanced device detection with proper error handling""" # Check Ascend NPU first try: import torch_npu if hasattr(torch, 'npu') and torch.npu.is_available(): # type: ignore[attr-defined] logger.info("Using Ascend NPU") return torch.device("npu") except ImportError: pass except Exception as e: logger.debug(f"NPU detection failed: {e}") pass # Check CUDA if torch.cuda.is_available(): logger.info("Using CUDA GPU") return 
torch.device("cuda") # Check MPS (Apple Silicon) try: if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # type: ignore[attr-defined] logger.info("Using Apple MPS") return torch.device("mps") except Exception as e: logger.debug(f"MPS detection failed: {e}") pass # Default to CPU logger.info("Using CPU") return torch.device("cpu") def _get_tokenizer(self): """Lazy load tokenizer with error handling""" if self._tokenizer is None: try: self._tokenizer = AutoTokenizer.from_pretrained( self.model_name_or_path, trust_remote_code=True ) logger.info("Tokenizer loaded successfully") except Exception as e: logger.error(f"Failed to load tokenizer: {e}") raise RuntimeError(f"Tokenizer loading failed: {e}") return self._tokenizer def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, return_dense: bool = True, return_sparse: bool = False, return_colbert: bool = False, return_dict: bool = True ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """ Forward pass for BGE-M3 with dense, sparse (SPLADE+), and ColBERT heads. Args: input_ids: Token IDs [batch_size, seq_length] attention_mask: Attention mask [batch_size, seq_length] token_type_ids: Token type IDs (optional) return_dense: Return dense embeddings return_sparse: Return sparse embeddings return_colbert: Return ColBERT embeddings return_dict: Return dictionary of outputs Returns: Embeddings or dictionary of embeddings """ # Validate inputs if input_ids is None or attention_mask is None: raise ValueError("input_ids and attention_mask are required") if input_ids.shape != attention_mask.shape: raise ValueError("input_ids and attention_mask must have the same shape") # Determine what to return based on model configuration and parameters return_dense = return_dense and self.use_dense return_sparse = return_sparse and self.use_sparse return_colbert = return_colbert and self.use_colbert if not (return_dense or return_sparse or return_colbert): logger.warning("No output head selected in forward. 

        if not (return_dense or return_sparse or return_colbert):
            logger.warning(
                "No output head selected in forward. "
                "At least one of dense, sparse, or colbert must be True."
            )
            return_dense = True  # Default to dense if nothing selected

        try:
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True
            )
            last_hidden_state = outputs.last_hidden_state  # [batch_size, seq_length, hidden_size]
        except Exception as e:
            logger.error(f"Model forward pass failed: {e}")
            raise RuntimeError(f"Forward pass failed: {e}")

        results = {}

        # Dense embeddings (CLS or mean pooling)
        if return_dense:
            if self.pooling_method == "cls" or self.sentence_pooling_method == "cls":
                dense_vecs = last_hidden_state[:, 0]
            elif self.pooling_method == "mean" or self.sentence_pooling_method == "mean":
                expanded_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.shape)
                masked_embeddings = last_hidden_state * expanded_mask
                sum_embeddings = torch.sum(masked_embeddings, dim=1)
                sum_mask = torch.clamp(attention_mask.sum(dim=1, keepdim=True), min=1e-9)
                dense_vecs = sum_embeddings / sum_mask
            else:
                raise ValueError(f"Unknown pooling method: {self.pooling_method}")

            dense_vecs = self.dense_linear(dense_vecs)

            # Ensure embeddings are float type
            dense_vecs = dense_vecs.float()

            if self.normalize_embeddings:
                dense_vecs = F.normalize(dense_vecs, p=2, dim=-1)

            results['dense'] = dense_vecs

        # SPLADE+ sparse head
        if return_sparse:
            # SPLADE+ uses log1p(ReLU(.))
            sparse_logits = self.sparse_linear(last_hidden_state).squeeze(-1)  # [batch, seq_len]
            sparse_activations = torch.log1p(F.relu(sparse_logits))

            # Mask out padding tokens
            sparse_activations = sparse_activations * attention_mask.float()

            # Optionally normalize sparse activations
            if self.normalize_embeddings:
                sparse_activations = F.normalize(sparse_activations, p=2, dim=-1)

            results['sparse'] = sparse_activations

            # FLOPS regularization term (consumed by the loss, stored here as an attribute)
            self.splade_flops = torch.sum(sparse_activations)

        # ColBERT head with late interaction (maxsim)
        if return_colbert:
            colbert_vecs = self.colbert_linear(last_hidden_state)  # [batch, seq_len, colbert_dim]

            # Mask out padding tokens
            expanded_mask = attention_mask.unsqueeze(-1).expand(colbert_vecs.shape)
            colbert_vecs = colbert_vecs * expanded_mask

            if self.normalize_embeddings:
                colbert_vecs = F.normalize(colbert_vecs, p=2, dim=-1)

            results['colbert'] = colbert_vecs

        if return_dict:
            return results

        # If only one head was requested, return that tensor directly
        if return_dense and len(results) == 1:
            return results['dense']
        if return_sparse and len(results) == 1:
            return results['sparse']
        if return_colbert and len(results) == 1:
            return results['colbert']

        return results

    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        max_length: int = 512,
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        show_progress_bar: bool = False,
        normalize_embeddings: bool = True,
        return_dense: bool = True,
        return_sparse: bool = False,
        return_colbert: bool = False,
        device: Optional[str] = None
    ) -> Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]:
        """
        Complete encode implementation with comprehensive error handling

        Args:
            sentences: Input sentences
            batch_size: Batch size for encoding
            max_length: Maximum sequence length
            convert_to_numpy: Convert to numpy arrays
            convert_to_tensor: Keep as tensors
            show_progress_bar: Show progress bar
            normalize_embeddings: Normalize embeddings
            return_dense/sparse/colbert: Which embeddings to return
            device: Device to use

        Returns:
            Embeddings
        """
        # Input validation
        if isinstance(sentences, str):
            sentences = [sentences]
ValueError("Empty sentences list provided") if batch_size <= 0: raise ValueError("batch_size must be positive") if max_length <= 0: raise ValueError("max_length must be positive") # Get tokenizer try: tokenizer = self._get_tokenizer() except Exception as e: raise RuntimeError(f"Failed to load tokenizer: {e}") # Determine target device target_device = device if device else self.device # Initialize storage for embeddings all_embeddings = { 'dense': [] if return_dense else None, 'sparse': [] if return_sparse else None, 'colbert': [] if return_colbert else None } # Set model to evaluation mode self.eval() try: # Import tqdm if show_progress_bar is True if show_progress_bar: try: from tqdm import tqdm iterator = tqdm( range(0, len(sentences), batch_size), desc="Encoding", total=(len(sentences) + batch_size - 1) // batch_size ) except ImportError: logger.warning("tqdm not available, disabling progress bar") iterator = range(0, len(sentences), batch_size) else: iterator = range(0, len(sentences), batch_size) # Process in batches for i in iterator: batch_sentences = sentences[i:i + batch_size] # Tokenize with proper error handling try: encoded = tokenizer( batch_sentences, padding=True, truncation=True, max_length=max_length, return_tensors='pt' ) except Exception as e: raise RuntimeError(f"Tokenization failed for batch {i//batch_size + 1}: {e}") # Move to device try: encoded = {k: v.to(target_device) for k, v in encoded.items()} except Exception as e: raise RuntimeError(f"Failed to move batch to device {target_device}: {e}") # Get embeddings with torch.no_grad(): try: outputs = self.forward( **encoded, return_dense=return_dense, return_sparse=return_sparse, return_colbert=return_colbert, return_dict=True ) except Exception as e: raise RuntimeError(f"Forward pass failed for batch {i//batch_size + 1}: {e}") # Collect embeddings for key in ['dense', 'sparse', 'colbert']: if all_embeddings[key] is not None and key in outputs: all_embeddings[key].append(outputs[key].cpu()) # type: ignore[arg-type] # Concatenate and convert results final_embeddings = {} for key in ['dense', 'sparse', 'colbert']: if all_embeddings[key] is not None and len(all_embeddings[key]) > 0: # type: ignore[arg-type] try: embeddings = torch.cat(all_embeddings[key], dim=0) # type: ignore[arg-type] if normalize_embeddings and key == 'dense': embeddings = F.normalize(embeddings, p=2, dim=-1) if convert_to_numpy: embeddings = embeddings.numpy() elif not convert_to_tensor: embeddings = embeddings final_embeddings[key] = embeddings except Exception as e: logger.error(f"Failed to process {key} embeddings: {e}") raise # Return format if len(final_embeddings) == 1: return list(final_embeddings.values())[0] if not final_embeddings: raise RuntimeError("No embeddings were generated") return final_embeddings except Exception as e: logger.error(f"Encoding failed: {e}") raise def save_pretrained(self, save_path: str): """Save model to directory with error handling, including custom heads and config.""" try: import os import json os.makedirs(save_path, exist_ok=True) # Save the base model self.model.save_pretrained(save_path) # Save additional configuration (including [m3] section) from utils.config_loader import get_config config_dict = get_config().get_section('m3') config = { 'use_dense': self.use_dense, 'use_sparse': self.use_sparse, 'use_colbert': self.use_colbert, 'pooling_method': self.pooling_method, 'normalize_embeddings': self.normalize_embeddings, 'colbert_dim': self.colbert_dim, 'sparse_top_k': self.sparse_top_k, 'model_name_or_path': 
                'model_name_or_path': self.model_name_or_path,
                'm3_config': config_dict
            }

            with open(os.path.join(save_path, 'm3_config.json'), 'w') as f:
                json.dump(config, f, indent=2)

            # Save custom heads if needed
            if self.use_sparse:
                torch.save(self.sparse_linear.state_dict(), os.path.join(save_path, 'sparse_linear.pt'))
            if self.use_colbert:
                torch.save(self.colbert_linear.state_dict(), os.path.join(save_path, 'colbert_linear.pt'))

            # Save tokenizer if available
            if self._tokenizer is not None:
                self._tokenizer.save_pretrained(save_path)

            logger.info(f"Model and config saved to {save_path}")

        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from directory, restoring custom heads and config."""
        import json

        try:
            # Load config
            config_path = os.path.join(model_path, "m3_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}

            # Instantiate model
            model = cls(
                model_name_or_path=config.get('model_name_or_path', model_path),
                use_dense=config.get('use_dense', True),
                use_sparse=config.get('use_sparse', True),
                use_colbert=config.get('use_colbert', True),
                pooling_method=config.get('pooling_method', 'cls'),
                normalize_embeddings=config.get('normalize_embeddings', True),
                colbert_dim=config.get('colbert_dim', -1),
                sparse_top_k=config.get('sparse_top_k', 100),
                **kwargs
            )

            # Load base model weights and keep them on the wrapper's device
            model.model = model.model.from_pretrained(model_path)
            model.model.to(model.device)

            # Load custom heads if present
            if model.use_sparse:
                sparse_path = os.path.join(model_path, 'sparse_linear.pt')
                if os.path.exists(sparse_path):
                    model.sparse_linear.load_state_dict(torch.load(sparse_path, map_location='cpu'))
            if model.use_colbert:
                colbert_path = os.path.join(model_path, 'colbert_linear.pt')
                if os.path.exists(colbert_path):
                    model.colbert_linear.load_state_dict(torch.load(colbert_path, map_location='cpu'))

            logger.info(f"Loaded BGE-M3 model from {model_path}")
            return model

        except Exception as e:
            logger.error(f"Failed to load model from {model_path}: {e}")
            raise

    def compute_similarity(
        self,
        queries: Union[torch.Tensor, Dict[str, torch.Tensor]],
        passages: Union[torch.Tensor, Dict[str, torch.Tensor]],
        mode: str = "dense"
    ) -> torch.Tensor:
        """
        Compute similarity scores between queries and passages with error handling

        Args:
            queries: Query embeddings
            passages: Passage embeddings
            mode: Similarity mode ('dense', 'sparse', 'colbert')

        Returns:
            Similarity scores [num_queries, num_passages]
        """
        try:
            # Extract embeddings based on mode
            if isinstance(queries, dict):
                if mode not in queries:
                    raise ValueError(f"Mode '{mode}' not found in query embeddings")
                query_emb = queries[mode]
            else:
                query_emb = queries

            if isinstance(passages, dict):
                if mode not in passages:
                    raise ValueError(f"Mode '{mode}' not found in passage embeddings")
                passage_emb = passages[mode]
            else:
                passage_emb = passages

            # Validate tensor shapes
            if query_emb.dim() < 2 or passage_emb.dim() < 2:
                raise ValueError("Embeddings must be at least 2-dimensional")

            if mode == "dense":
                # Simple dot product for dense
                return torch.matmul(query_emb, passage_emb.t())

            elif mode == "sparse":
                # Dot product for sparse (simplified)
                return torch.matmul(query_emb, passage_emb.t())

            elif mode == "colbert":
                # MaxSim for ColBERT
                if query_emb.dim() != 3 or passage_emb.dim() != 3:
                    raise ValueError("ColBERT embeddings must be 3-dimensional [batch, seq_len, dim]")

                scores = []
                for q in query_emb:
                    # q: [query_len, dim]
                    # passage_emb: [num_passages, passage_len, dim]
                    passage_scores = []
                    for p in passage_emb:
                        # Compute token-level similarity matrix between query and passage
                        sim_matrix = torch.matmul(q, p.t())  # [query_len, passage_len]

                        # Max over passage tokens for each query token
                        max_sim_scores = sim_matrix.max(dim=1)[0]  # [query_len]

                        # Sum over query tokens
                        score = max_sim_scores.sum()
                        passage_scores.append(score)

                    scores.append(torch.stack(passage_scores))

                return torch.stack(scores)

            else:
                raise ValueError(f"Unknown similarity mode: {mode}")

        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            raise

    def get_sentence_embedding_dimension(self) -> int:
        """Get the dimension of sentence embeddings"""
        return self.config.hidden_size  # type: ignore[attr-defined]

    def get_word_embedding_dimension(self) -> int:
        """Get the dimension of word embeddings"""
        return self.config.hidden_size  # type: ignore[attr-defined]
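

# Minimal usage sketch, not part of the training pipeline. It assumes the "BAAI/bge-m3"
# checkpoint (or a local copy) is reachable through load_model_and_tokenizer() and that
# utils.config_loader.get_config() returns a valid configuration; it only smoke-tests
# encode() and compute_similarity() with the dense head.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    demo_model = BGEM3Model(model_name_or_path="BAAI/bge-m3", use_sparse=False, use_colbert=False)

    # Encode a query and two passages into dense embeddings, kept as tensors
    query_emb = demo_model.encode(
        ["what is dense retrieval?"],
        convert_to_numpy=False,
        convert_to_tensor=True,
    )
    passage_emb = demo_model.encode(
        [
            "Dense retrieval maps text to vectors and ranks passages by similarity.",
            "A completely unrelated sentence about cooking pasta.",
        ],
        convert_to_numpy=False,
        convert_to_tensor=True,
    )

    # Dense similarity is a dot product of L2-normalized CLS embeddings (cosine similarity)
    scores = demo_model.compute_similarity(query_emb, passage_emb, mode="dense")
    print(scores)  # shape [1, 2]; the first passage should score higher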