""" BGE-Reranker cross-encoder model wrapper """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification from typing import Dict, List, Optional, Union, Tuple import numpy as np import logging from dataclasses import dataclass from utils.config_loader import get_config logger = logging.getLogger(__name__) @dataclass class RerankerOutput: """Output class for reranker predictions""" scores: torch.Tensor logits: torch.Tensor embeddings: Optional[torch.Tensor] = None def load_reranker_model_and_tokenizer(model_name_or_path, cache_dir=None, num_labels=1): """Load reranker model and tokenizer from HuggingFace or ModelScope based on config.toml""" config = get_config() # Try to get model source from config, with fallback to huggingface try: model_source = config.get('model_paths', 'source', 'huggingface') except Exception: # Fallback if config loading fails logger.warning("Failed to load model source from config, defaulting to 'huggingface'") model_source = 'huggingface' logger.info(f"Loading reranker model from source: {model_source}") if model_source == 'huggingface': from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig model_config = AutoConfig.from_pretrained( model_name_or_path, num_labels=num_labels, cache_dir=cache_dir ) try: model = AutoModelForSequenceClassification.from_pretrained( model_name_or_path, config=model_config, cache_dir=cache_dir ) use_base_model = False except: model = AutoModel.from_pretrained( model_name_or_path, config=model_config, cache_dir=cache_dir ) use_base_model = True tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, cache_dir=cache_dir ) return model, tokenizer, model_config, use_base_model elif model_source == 'modelscope': try: from modelscope.models import Model as MSModel from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir) model = MSModel.from_pretrained(model_name_or_path, cache_dir=cache_dir) tokenizer = None # ModelScope tokenizer loading can be added if needed model_config = None use_base_model = False return model, tokenizer, model_config, use_base_model except ImportError: logger.error("modelscope not available. Install it first: pip install modelscope") raise except Exception as e: logger.error(f"Failed to load reranker model from ModelScope: {e}") raise else: logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.") raise RuntimeError(f"Unknown model source: {model_source}") class BGERerankerModel(nn.Module): """ BGE-Reranker cross-encoder model for passage reranking Notes: - All parameters (model_name_or_path, num_labels, cache_dir, use_fp16, gradient_checkpointing, device, etc.) should be passed from config-driven training scripts. - This class does not load config directly; pass config values from your main script for consistency. """ def __init__( self, model_name_or_path: str = "BAAI/bge-reranker-base", num_labels: int = 1, cache_dir: Optional[str] = None, use_fp16: bool = False, gradient_checkpointing: bool = False, device: Optional[str] = None ): """ Initialize BGE-Reranker model. 
class BGERerankerModel(nn.Module):
    """
    BGE-Reranker cross-encoder model for passage reranking.

    Notes:
        - All parameters (model_name_or_path, num_labels, cache_dir, use_fp16,
          gradient_checkpointing, device, etc.) should be passed from
          config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-reranker-base",
        num_labels: int = 1,
        cache_dir: Optional[str] = None,
        use_fp16: bool = False,
        gradient_checkpointing: bool = False,
        device: Optional[str] = None
    ):
        """
        Initialize BGE-Reranker model.

        Args:
            model_name_or_path: Path or identifier for pretrained model
            num_labels: Number of output labels
            cache_dir: Directory for model cache
            use_fp16: Use mixed precision
            gradient_checkpointing: Enable gradient checkpointing
            device: Device to use
        """
        super().__init__()

        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels
        self.use_fp16 = use_fp16

        # Load configuration and model
        try:
            logger.info(f"Loading BGE-Reranker model from {model_name_or_path} using source from config.toml...")
            self.model, self.tokenizer, self.config, self.use_base_model = load_reranker_model_and_tokenizer(
                model_name_or_path,
                cache_dir=cache_dir,
                num_labels=num_labels
            )
            logger.info("BGE-Reranker model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BGE-Reranker model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")

        # Add classification head when only a base encoder was loaded
        if self.use_base_model:
            self.classifier = nn.Linear(self.config.hidden_size, num_labels)  # type: ignore[attr-defined]
            self.dropout = nn.Dropout(getattr(self.config, 'hidden_dropout_prob', 0.1))

        # Enable gradient checkpointing
        if gradient_checkpointing:
            try:
                self.model.gradient_checkpointing_enable()
            except Exception:
                logger.warning("Gradient checkpointing requested but not supported by this model.")

        # Set device
        if device:
            try:
                self.device = torch.device(device)
                self.to(self.device)
            except Exception as e:
                logger.warning(f"Failed to move model to {device}: {e}, using CPU")
                self.device = torch.device('cpu')
                self.to(self.device)
        else:
            self.device = next(self.parameters()).device

        # Mixed precision
        if self.use_fp16:
            try:
                self.half()
            except Exception as e:
                logger.warning(f"Failed to set model to half precision: {e}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True
    ) -> Union[torch.Tensor, RerankerOutput]:
        """
        Forward pass for cross-encoder reranking.

        Args:
            input_ids: Token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            labels: Labels for training (optional)
            return_dict: Return RerankerOutput object

        Returns:
            Scores or RerankerOutput
        """
        if self.use_base_model:
            # Use base model + custom head
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True
            )

            # Get pooled output (CLS token)
            if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                pooled_output = outputs.pooler_output
            else:
                # Manual pooling if pooler not available
                last_hidden_state = outputs.last_hidden_state
                pooled_output = last_hidden_state[:, 0]

            # Apply dropout and classifier
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            embeddings = pooled_output
        else:
            # Use sequence classification model
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
                return_dict=True
            )
            logits = outputs.logits
            embeddings = None

        # Convert logits to scores (sigmoid for binary classification)
        if self.num_labels == 1:
            scores = torch.sigmoid(logits.squeeze(-1))
        else:
            scores = F.softmax(logits, dim=-1)

        if not return_dict:
            return scores

        return RerankerOutput(
            scores=scores,
            logits=logits,
            embeddings=embeddings
        )
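
    # Forward-pass sketch (hypothetical inputs; with num_labels == 1 the model
    # emits one logit per query-passage pair and `scores` is its sigmoid):
    #
    #   encoded = model.tokenizer(
    #       ["what is BGE?"], ["BGE is a family of embedding models."],
    #       padding=True, truncation=True, max_length=512, return_tensors='pt')
    #   with torch.no_grad():
    #       out = model(**{k: v.to(model.device) for k, v in encoded.items()})
    #   print(out.scores)  # relevance probabilities in [0, 1]
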
    def predict(
        self,
        queries: Union[str, List[str]],
        passages: Union[str, List[str], List[List[str]]],
        batch_size: int = 32,
        max_length: int = 512,
        show_progress_bar: bool = False,
        convert_to_numpy: bool = True,
        normalize_scores: bool = False
    ) -> Union[np.ndarray, torch.Tensor, List[float]]:
        """
        Predict reranking scores for query-passage pairs.

        Args:
            queries: Query or list of queries
            passages: Passage(s) or list of passages per query
            batch_size: Batch size for inference
            max_length: Maximum sequence length
            show_progress_bar: Show progress bar
            convert_to_numpy: Convert to numpy array
            normalize_scores: Normalize scores to [0, 1]

        Returns:
            Reranking scores
        """
        # Handle input formats
        if isinstance(queries, str):
            queries = [queries]
        if isinstance(passages, str):
            passages = [[passages]]
        elif isinstance(passages[0], str):
            passages = [passages]  # type: ignore[assignment]

        # Ensure queries and passages match
        if len(queries) == 1 and len(passages) > 1:
            queries = queries * len(passages)

        # Create pairs
        pairs = []
        for query, passage_list in zip(queries, passages):
            if isinstance(passage_list, str):
                passage_list = [passage_list]
            for passage in passage_list:
                pairs.append((query, passage))

        # Reuse the tokenizer loaded at init; fall back to loading on demand
        tokenizer = self.tokenizer
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)

        # Score in batches
        all_scores = []
        if show_progress_bar:
            from tqdm import tqdm
            pbar = tqdm(total=len(pairs), desc="Scoring")

        for i in range(0, len(pairs), batch_size):
            batch_pairs = pairs[i:i + batch_size]

            # Tokenize
            batch_queries = [pair[0] for pair in batch_pairs]
            batch_passages = [pair[1] for pair in batch_pairs]
            encoded = tokenizer(
                batch_queries,
                batch_passages,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors='pt'
            )

            # Move to device
            encoded = {k: v.to(self.device) for k, v in encoded.items()}

            # Get scores
            with torch.no_grad():
                outputs = self.forward(**encoded, return_dict=True)
                scores = outputs.scores  # type: ignore[attr-defined]

            all_scores.append(scores)

            if show_progress_bar:
                pbar.update(len(batch_pairs))

        if show_progress_bar:
            pbar.close()

        # Concatenate scores
        all_scores = torch.cat(all_scores, dim=0)

        # Normalize if requested
        if normalize_scores:
            all_scores = (all_scores - all_scores.min()) / (all_scores.max() - all_scores.min() + 1e-8)

        # Convert format
        if convert_to_numpy:
            all_scores = all_scores.cpu().numpy()
        else:
            all_scores = all_scores.cpu()

        # Collapse to a scalar only when exactly one pair was scored
        if len(pairs) == 1:
            return all_scores.item() if convert_to_numpy else all_scores  # type: ignore[attr-defined]

        return all_scores

    def rerank(
        self,
        query: str,
        passages: List[str],
        top_k: Optional[int] = None,
        return_scores: bool = False,
        batch_size: int = 32,
        max_length: int = 512
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        """
        Rerank passages for a query.

        Args:
            query: Query string
            passages: List of passages to rerank
            top_k: Return only top-k passages
            return_scores: Also return scores
            batch_size: Batch size for scoring
            max_length: Maximum sequence length

        Returns:
            Reranked passages (and optionally scores)
        """
        # Get scores
        scores = self.predict(
            query,
            passages,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=True
        )
        # predict() returns a scalar for a single pair; keep downstream indexing uniform
        scores = np.atleast_1d(scores)

        # Sort by scores (descending)
        sorted_indices = np.argsort(scores)[::-1]

        # Apply top_k if specified
        if top_k is not None:
            sorted_indices = sorted_indices[:top_k]

        # Get reranked passages
        reranked_passages = [passages[i] for i in sorted_indices]

        if return_scores:
            reranked_scores = [float(scores[i]) for i in sorted_indices]
            return reranked_passages, reranked_scores

        return reranked_passages
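
    # Reranking sketch (hypothetical query and passages; top_k trims the output
    # and return_scores pairs each passage with its relevance score):
    #
    #   ranked, scores = model.rerank(
    #       "what is a cross-encoder?",
    #       ["A cross-encoder scores query and passage jointly.",
    #        "Bi-encoders embed query and passage independently."],
    #       top_k=2,
    #       return_scores=True,
    #   )
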
heads and config.""" try: import os import json os.makedirs(save_path, exist_ok=True) # Save the base model self.model.save_pretrained(save_path) # Save additional configuration (including [reranker] section) from utils.config_loader import get_config config_dict = get_config().get_section('reranker') config = { 'use_fp16': self.use_fp16 if hasattr(self, 'use_fp16') else False, 'num_labels': getattr(self, 'num_labels', 1), 'reranker_config': config_dict } with open(f"{save_path}/reranker_config.json", 'w') as f: json.dump(config, f, indent=2) # Save custom heads if present if hasattr(self, 'classifier'): torch.save(self.classifier.state_dict(), os.path.join(save_path, 'classifier.pt')) if hasattr(self, 'dropout'): torch.save(self.dropout.state_dict(), os.path.join(save_path, 'dropout.pt')) # Save tokenizer if available if hasattr(self, '_tokenizer') and self._tokenizer is not None: self._tokenizer.save_pretrained(save_path) # type: ignore[attr-defined] logger.info(f"Reranker model and config saved to {save_path}") except Exception as e: logger.error(f"Failed to save reranker model: {e}") raise @classmethod def from_pretrained(cls, model_path: str, **kwargs): """Load model from directory, restoring custom heads and config.""" import os import json try: # Load config config_path = os.path.join(model_path, "reranker_config.json") if os.path.exists(config_path): with open(config_path, 'r') as f: config = json.load(f) else: config = {} # Instantiate model model = cls( model_name_or_path=model_path, num_labels=config.get('num_labels', 1), use_fp16=config.get('use_fp16', False), **kwargs ) # Load base model weights model.model = model.model.from_pretrained(model_path) # Load custom heads if present classifier_path = os.path.join(model_path, 'classifier.pt') if hasattr(model, 'classifier') and os.path.exists(classifier_path): model.classifier.load_state_dict(torch.load(classifier_path, map_location='cpu')) dropout_path = os.path.join(model_path, 'dropout.pt') if hasattr(model, 'dropout') and os.path.exists(dropout_path): model.dropout.load_state_dict(torch.load(dropout_path, map_location='cpu')) logger.info(f"Loaded BGE-Reranker model from {model_path}") return model except Exception as e: logger.error(f"Failed to load reranker model from {model_path}: {e}") raise def compute_loss( self, scores: torch.Tensor, labels: torch.Tensor, loss_type: str = "binary_cross_entropy" ) -> torch.Tensor: """ Compute loss for training Args: scores: Model scores labels: Ground truth labels loss_type: Type of loss function Returns: Loss value """ if loss_type == "binary_cross_entropy": return F.binary_cross_entropy(scores, labels.float()) elif loss_type == "cross_entropy": return F.cross_entropy(scores, labels.long()) elif loss_type == "mse": return F.mse_loss(scores, labels.float()) else: raise ValueError(f"Unknown loss type: {loss_type}")