"""
|
|
BGE-Reranker cross-encoder model wrapper
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
|
|
from typing import Dict, List, Optional, Union, Tuple
|
|
import numpy as np
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from utils.config_loader import get_config
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RerankerOutput:
    """Output container for reranker predictions."""
    scores: torch.Tensor
    logits: torch.Tensor
    embeddings: Optional[torch.Tensor] = None


def load_reranker_model_and_tokenizer(model_name_or_path, cache_dir=None, num_labels=1):
    """Load a reranker model and tokenizer from HuggingFace or ModelScope, based on config.toml."""
    config = get_config()

    # Read the model source from config, falling back to HuggingFace
    try:
        model_source = config.get('model_paths', 'source', 'huggingface')
    except Exception:
        logger.warning("Failed to load model source from config, defaulting to 'huggingface'")
        model_source = 'huggingface'

    logger.info(f"Loading reranker model from source: {model_source}")

    if model_source == 'huggingface':
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=num_labels,
            cache_dir=cache_dir
        )
        try:
            # Prefer a model with a built-in classification head
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name_or_path,
                config=model_config,
                cache_dir=cache_dir
            )
            use_base_model = False
        except Exception:
            # Fall back to the base encoder; the caller attaches a custom head
            model = AutoModel.from_pretrained(
                model_name_or_path,
                config=model_config,
                cache_dir=cache_dir
            )
            use_base_model = True
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir
        )
        return model, tokenizer, model_config, use_base_model
    elif model_source == 'modelscope':
        try:
            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
            from modelscope.models import Model as MSModel

            ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
            model = MSModel.from_pretrained(model_name_or_path, cache_dir=cache_dir)
            tokenizer = None  # ModelScope tokenizer loading can be added if needed
            model_config = None
            use_base_model = False
            return model, tokenizer, model_config, use_base_model
        except ImportError:
            logger.error("modelscope not available. Install it first: pip install modelscope")
            raise
        except Exception as e:
            logger.error(f"Failed to load reranker model from ModelScope: {e}")
            raise
    else:
        logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
        raise RuntimeError(f"Unknown model source: {model_source}")


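# Example (sketch): loading the default BGE reranker through this helper.
# The cache path is illustrative; which hub is queried depends on the
# model_paths.source entry in config.toml:
#
#   model, tokenizer, cfg, use_base = load_reranker_model_and_tokenizer(
#       "BAAI/bge-reranker-base", cache_dir="./model_cache", num_labels=1
#   )
#   # use_base is True only when no sequence-classification head was available.

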
class BGERerankerModel(nn.Module):
    """
    BGE-Reranker cross-encoder model for passage reranking.

    Notes:
        - All parameters (model_name_or_path, num_labels, cache_dir, use_fp16,
          gradient_checkpointing, device, etc.) should be passed in from
          config-driven training scripts.
        - This class does not read config directly; pass config values from
          your main script for consistency.
    """

    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-reranker-base",
        num_labels: int = 1,
        cache_dir: Optional[str] = None,
        use_fp16: bool = False,
        gradient_checkpointing: bool = False,
        device: Optional[str] = None
    ):
        """
        Initialize the BGE-Reranker model.

        Args:
            model_name_or_path: Path or identifier of the pretrained model
            num_labels: Number of output labels
            cache_dir: Directory for the model cache
            use_fp16: Use half-precision weights
            gradient_checkpointing: Enable gradient checkpointing
            device: Device to place the model on
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels
        self.use_fp16 = use_fp16

        # Load configuration, model, and tokenizer
        try:
            logger.info(f"Loading BGE-Reranker model from {model_name_or_path} using source from config.toml...")
            self.model, self.tokenizer, self.config, self.use_base_model = load_reranker_model_and_tokenizer(
                model_name_or_path,
                cache_dir=cache_dir,
                num_labels=num_labels
            )
            logger.info("BGE-Reranker model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BGE-Reranker model: {e}")
            raise RuntimeError(f"Model initialization failed: {e}")

        # Attach a classification head when only a base encoder was loaded
        if self.use_base_model:
            self.classifier = nn.Linear(self.config.hidden_size, num_labels)  # type: ignore[attr-defined]
            self.dropout = nn.Dropout(getattr(self.config, 'hidden_dropout_prob', 0.1))

        # Enable gradient checkpointing
        if gradient_checkpointing:
            try:
                self.model.gradient_checkpointing_enable()
            except Exception:
                logger.warning("Gradient checkpointing requested but not supported by this model.")

        # Set device
        if device:
            try:
                self.device = torch.device(device)
                self.to(self.device)
            except Exception as e:
                logger.warning(f"Failed to move model to {device}: {e}, using CPU")
                self.device = torch.device('cpu')
                self.to(self.device)
        else:
            self.device = next(self.parameters()).device

        # Mixed precision
        if self.use_fp16:
            try:
                self.half()
            except Exception as e:
                logger.warning(f"Failed to set model to half precision: {e}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True
    ) -> Union[torch.Tensor, RerankerOutput]:
        """
        Forward pass for cross-encoder reranking.

        Args:
            input_ids: Token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            labels: Labels for training (optional)
            return_dict: Return a RerankerOutput object

        Returns:
            Scores, or a RerankerOutput when return_dict is True
        """
        if self.use_base_model:
            # Base encoder + custom classification head
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True
            )

            # Use the pooled ([CLS]) output when the model provides one
            if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                pooled_output = outputs.pooler_output
            else:
                # Fall back to the [CLS] position of the last hidden state
                pooled_output = outputs.last_hidden_state[:, 0]

            # Apply dropout and the classification head
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)

            embeddings = pooled_output
        else:
            # Sequence-classification model with a built-in head
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
                return_dict=True
            )
            logits = outputs.logits
            embeddings = None

        # Convert logits to scores: sigmoid for a single label, softmax otherwise
        if self.num_labels == 1:
            scores = torch.sigmoid(logits.squeeze(-1))
        else:
            scores = F.softmax(logits, dim=-1)

        if not return_dict:
            return scores

        return RerankerOutput(
            scores=scores,
            logits=logits,
            embeddings=embeddings
        )

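    # Example (sketch): calling forward() directly on a tokenized batch.
    # Assumes the HuggingFace source, so self.tokenizer is available:
    #
    #   enc = model.tokenizer(["what is bge?"], ["BGE is a family of embedding models."],
    #                         padding=True, truncation=True, return_tensors="pt")
    #   out = model(**{k: v.to(model.device) for k, v in enc.items()})
    #   print(out.scores)  # sigmoid probabilities when num_labels == 1
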
    def predict(
        self,
        queries: Union[str, List[str]],
        passages: Union[str, List[str], List[List[str]]],
        batch_size: int = 32,
        max_length: int = 512,
        show_progress_bar: bool = False,
        convert_to_numpy: bool = True,
        normalize_scores: bool = False
    ) -> Union[np.ndarray, torch.Tensor, List[float]]:
        """
        Predict reranking scores for query-passage pairs.

        Args:
            queries: Query or list of queries
            passages: Passage(s), or one list of passages per query
            batch_size: Batch size for inference
            max_length: Maximum sequence length
            show_progress_bar: Show a progress bar
            convert_to_numpy: Convert scores to a numpy array
            normalize_scores: Min-max normalize scores to [0, 1]

        Returns:
            Reranking scores
        """
        # Normalize input formats
        if isinstance(queries, str):
            queries = [queries]

        if isinstance(passages, str):
            passages = [[passages]]
        elif isinstance(passages[0], str):
            passages = [passages]  # type: ignore[assignment]

        # Broadcast a single query across all passage lists
        if len(queries) == 1 and len(passages) > 1:
            queries = queries * len(passages)

        # Build (query, passage) pairs
        pairs = []
        for query, passage_list in zip(queries, passages):
            if isinstance(passage_list, str):
                passage_list = [passage_list]
            for passage in passage_list:
                pairs.append((query, passage))

        # Reuse the tokenizer loaded at init; fall back to loading on demand
        # (the ModelScope path leaves self.tokenizer as None)
        tokenizer = self.tokenizer
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)

        # Score in batches
        all_scores = []

        if show_progress_bar:
            from tqdm import tqdm
            pbar = tqdm(total=len(pairs), desc="Scoring")

        for i in range(0, len(pairs), batch_size):
            batch_pairs = pairs[i:i + batch_size]

            # Tokenize query-passage pairs jointly (cross-encoder input)
            batch_queries = [pair[0] for pair in batch_pairs]
            batch_passages = [pair[1] for pair in batch_pairs]

            encoded = tokenizer(
                batch_queries,
                batch_passages,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors='pt'
            )

            # Move to device
            encoded = {k: v.to(self.device) for k, v in encoded.items()}

            # Score without tracking gradients
            with torch.no_grad():
                outputs = self.forward(**encoded, return_dict=True)
                scores = outputs.scores  # type: ignore[attr-defined]

            all_scores.append(scores)

            if show_progress_bar:
                pbar.update(len(batch_pairs))

        if show_progress_bar:
            pbar.close()

        # Concatenate batch scores
        all_scores = torch.cat(all_scores, dim=0)

        # Min-max normalize if requested
        if normalize_scores:
            all_scores = (all_scores - all_scores.min()) / (all_scores.max() - all_scores.min() + 1e-8)

        # Convert to the requested format
        if convert_to_numpy:
            all_scores = all_scores.cpu().numpy()

        # Return a scalar for a single query-passage pair
        if len(pairs) == 1:
            return all_scores.item() if convert_to_numpy else all_scores

        return all_scores

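    # Example (sketch): scoring one query against several passages with predict().
    # Texts are illustrative; the result is a numpy array with one score per passage:
    #
    #   scores = model.predict(
    #       "capital of france",
    #       ["Paris is the capital of France.", "Berlin is the capital of Germany."],
    #   )
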
    def rerank(
        self,
        query: str,
        passages: List[str],
        top_k: Optional[int] = None,
        return_scores: bool = False,
        batch_size: int = 32,
        max_length: int = 512
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        """
        Rerank passages for a query.

        Args:
            query: Query string
            passages: List of passages to rerank
            top_k: Return only the top-k passages
            return_scores: Also return scores
            batch_size: Batch size for scoring
            max_length: Maximum sequence length

        Returns:
            Reranked passages (and optionally their scores)
        """
        # Score all query-passage pairs
        scores = self.predict(
            query,
            passages,
            batch_size=batch_size,
            max_length=max_length,
            convert_to_numpy=True
        )
        # predict() returns a scalar for a single passage; make it indexable
        scores = np.atleast_1d(np.asarray(scores))

        # Sort by score, descending
        sorted_indices = np.argsort(scores)[::-1]

        # Apply top_k if specified
        if top_k is not None:
            sorted_indices = sorted_indices[:top_k]

        # Gather the reranked passages
        reranked_passages = [passages[i] for i in sorted_indices]

        if return_scores:
            reranked_scores = [float(scores[i]) for i in sorted_indices]
            return reranked_passages, reranked_scores

        return reranked_passages

    def save_pretrained(self, save_path: str):
        """Save the model to a directory, including custom heads and config."""
        try:
            import json
            import os
            os.makedirs(save_path, exist_ok=True)

            # Save the base model
            self.model.save_pretrained(save_path)

            # Save additional configuration (including the [reranker] section)
            config_dict = get_config().get_section('reranker')
            config = {
                'use_fp16': getattr(self, 'use_fp16', False),
                'num_labels': getattr(self, 'num_labels', 1),
                'reranker_config': config_dict
            }
            with open(os.path.join(save_path, 'reranker_config.json'), 'w') as f:
                json.dump(config, f, indent=2)

            # Save custom heads if present
            if hasattr(self, 'classifier'):
                torch.save(self.classifier.state_dict(), os.path.join(save_path, 'classifier.pt'))
            if hasattr(self, 'dropout'):
                torch.save(self.dropout.state_dict(), os.path.join(save_path, 'dropout.pt'))

            # Save the tokenizer if available
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(save_path)

            logger.info(f"Reranker model and config saved to {save_path}")
        except Exception as e:
            logger.error(f"Failed to save reranker model: {e}")
            raise

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load a model from a directory, restoring custom heads and config."""
        import json
        import os
        try:
            # Load the saved reranker config
            config_path = os.path.join(model_path, "reranker_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}

            # Instantiate the model; the constructor already loads the base
            # weights from model_path, so no separate reload is needed
            model = cls(
                model_name_or_path=model_path,
                num_labels=config.get('num_labels', 1),
                use_fp16=config.get('use_fp16', False),
                **kwargs
            )

            # Restore custom heads if present
            classifier_path = os.path.join(model_path, 'classifier.pt')
            if hasattr(model, 'classifier') and os.path.exists(classifier_path):
                model.classifier.load_state_dict(torch.load(classifier_path, map_location='cpu'))
            dropout_path = os.path.join(model_path, 'dropout.pt')
            if hasattr(model, 'dropout') and os.path.exists(dropout_path):
                model.dropout.load_state_dict(torch.load(dropout_path, map_location='cpu'))

            logger.info(f"Loaded BGE-Reranker model from {model_path}")
            return model
        except Exception as e:
            logger.error(f"Failed to load reranker model from {model_path}: {e}")
            raise

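    # Example (sketch): save/load round trip with the two methods above.
    # The checkpoint path is illustrative:
    #
    #   model.save_pretrained("./checkpoints/bge_reranker")
    #   restored = BGERerankerModel.from_pretrained("./checkpoints/bge_reranker")
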
    def compute_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        loss_type: str = "binary_cross_entropy"
    ) -> torch.Tensor:
        """
        Compute the training loss.

        Args:
            scores: Model scores; for "cross_entropy" pass the raw logits
                (RerankerOutput.logits), since F.cross_entropy expects
                unnormalized logits rather than softmax probabilities
            labels: Ground-truth labels
            loss_type: Loss function to use

        Returns:
            Loss value
        """
        if loss_type == "binary_cross_entropy":
            # Expects sigmoid probabilities, as produced by forward()
            return F.binary_cross_entropy(scores, labels.float())
        elif loss_type == "cross_entropy":
            # Expects raw logits, not softmax scores
            return F.cross_entropy(scores, labels.long())
        elif loss_type == "mse":
            return F.mse_loss(scores, labels.float())
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")
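

# Minimal usage sketch (not part of the library API). The query and passages
# are illustrative; the first run downloads weights from the configured source.
if __name__ == "__main__":
    reranker = BGERerankerModel("BAAI/bge-reranker-base")
    reranker.eval()

    query = "What does a cross-encoder reranker do?"
    passages = [
        "A cross-encoder scores a query and a passage jointly in a single forward pass.",
        "The weather in Paris is mild in spring.",
    ]

    # Rerank the passages and print them with their relevance scores
    ranked, scores = reranker.rerank(query, passages, return_scores=True)
    for passage, score in zip(ranked, scores):
        print(f"{score:.4f}  {passage}")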