# bge_finetune/models/bge_reranker.py
# 2025-07-22 16:55:25 +08:00
#
# 484 lines
# 17 KiB
# Python

"""
BGE-Reranker cross-encoder model wrapper
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from typing import Dict, List, Optional, Union, Tuple
import numpy as np
import logging
from dataclasses import dataclass
from utils.config_loader import get_config
logger = logging.getLogger(__name__)
@dataclass
class RerankerOutput:
    """Result bundle produced by the reranker's forward pass.

    Attributes:
        scores: Probabilities derived from ``logits`` (sigmoid for a single
            label, softmax otherwise).
        logits: Raw classifier outputs before any activation.
        embeddings: Pooled representation fed to the custom head when the
            model runs as "bare encoder + head"; ``None`` otherwise.
    """

    # Field order is part of the positional-construction contract.
    scores: torch.Tensor
    logits: torch.Tensor
    embeddings: Optional[torch.Tensor] = None
def load_reranker_model_and_tokenizer(model_name_or_path, cache_dir=None, num_labels=1):
    """Load reranker model and tokenizer from HuggingFace or ModelScope based on config.toml.

    The model source is read from the ``model_paths.source`` config entry and
    defaults to ``'huggingface'`` when that lookup fails.

    Args:
        model_name_or_path: Model identifier or local path passed to the loader.
        cache_dir: Optional cache directory forwarded to the loader.
        num_labels: Number of output labels for the classification config
            (HuggingFace source only).

    Returns:
        Tuple ``(model, tokenizer, model_config, use_base_model)``.
        ``use_base_model`` is True when the sequence-classification class could
        not be loaded and a bare ``AutoModel`` was loaded instead. For the
        ModelScope source ``tokenizer`` and ``model_config`` are ``None``.

    Raises:
        RuntimeError: If the configured source is neither ``'huggingface'``
            nor ``'modelscope'``.
        ImportError: If the ModelScope source is selected but the package is
            not installed.
    """
    config = get_config()
    # Try to get model source from config, with fallback to huggingface
    try:
        model_source = config.get('model_paths', 'source', 'huggingface')
    except Exception:
        # Fallback if config loading fails
        logger.warning("Failed to load model source from config, defaulting to 'huggingface'")
        model_source = 'huggingface'
    logger.info(f"Loading reranker model from source: {model_source}")

    if model_source == 'huggingface':
        from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            num_labels=num_labels,
            cache_dir=cache_dir
        )
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name_or_path,
                config=model_config,
                cache_dir=cache_dir
            )
            use_base_model = False
        except Exception as e:
            # Only genuine load errors trigger the fallback; KeyboardInterrupt
            # and SystemExit must propagate instead of being swallowed.
            logger.warning(
                f"Could not load {model_name_or_path} as a sequence-classification model ({e}); "
                "falling back to the base encoder"
            )
            model = AutoModel.from_pretrained(
                model_name_or_path,
                config=model_config,
                cache_dir=cache_dir
            )
            use_base_model = True
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir
        )
        return model, tokenizer, model_config, use_base_model

    if model_source == 'modelscope':
        try:
            from modelscope.models import Model as MSModel
            from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
            ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
            model = MSModel.from_pretrained(model_name_or_path, cache_dir=cache_dir)
            tokenizer = None  # ModelScope tokenizer loading can be added if needed
            model_config = None
            use_base_model = False
            return model, tokenizer, model_config, use_base_model
        except ImportError:
            logger.error("modelscope not available. Install it first: pip install modelscope")
            raise
        except Exception as e:
            logger.error(f"Failed to load reranker model from ModelScope: {e}")
            raise

    logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
    raise RuntimeError(f"Unknown model source: {model_source}")
class BGERerankerModel(nn.Module):
    """
    BGE-Reranker cross-encoder model for passage reranking.

    Wraps either a sequence-classification checkpoint (its own head produces
    the logits) or, when ``use_base_model`` is True, a bare encoder followed by
    the dropout + linear head defined on this class.

    Notes:
        - All parameters (model_name_or_path, num_labels, cache_dir, use_fp16,
          gradient_checkpointing, device, etc.) should be passed from
          config-driven training scripts.
        - The constructor does not read config values itself; pass them from
          your main script for consistency.
    """

    def __init__(
        self,
        model_name_or_path: str = "BAAI/bge-reranker-base",
        num_labels: int = 1,
        cache_dir: Optional[str] = None,
        use_fp16: bool = False,
        gradient_checkpointing: bool = False,
        device: Optional[str] = None
    ):
        """
        Initialize BGE-Reranker model.

        Args:
            model_name_or_path: Path or identifier for pretrained model
            num_labels: Number of output labels
            cache_dir: Directory for model cache
            use_fp16: Cast the whole module to half precision after loading
            gradient_checkpointing: Enable gradient checkpointing
            device: Device to use; when omitted the device of the loaded
                parameters is kept

        Raises:
            RuntimeError: If the underlying model cannot be loaded.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels
        self.use_fp16 = use_fp16

        # Load configuration and model
        try:
            logger.info(f"Loading BGE-Reranker model from {model_name_or_path} using source from config.toml...")
            self.model, self.tokenizer, self.config, self.use_base_model = load_reranker_model_and_tokenizer(
                model_name_or_path,
                cache_dir=cache_dir,
                num_labels=num_labels
            )
            logger.info("BGE-Reranker model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load BGE-Reranker model: {e}")
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Model initialization failed: {e}") from e

        # Add classification head (only needed on top of a bare encoder)
        if self.use_base_model:
            self.classifier = nn.Linear(self.config.hidden_size, num_labels)  # type: ignore[attr-defined]
            self.dropout = nn.Dropout(getattr(self.config, 'hidden_dropout_prob', 0.1))

        # Enable gradient checkpointing
        if gradient_checkpointing:
            try:
                self.model.gradient_checkpointing_enable()
            except Exception:
                logger.warning("Gradient checkpointing requested but not supported by this model.")

        # Set device
        if device:
            try:
                self.device = torch.device(device)
                self.to(self.device)
            except Exception as e:
                logger.warning(f"Failed to move model to {device}: {e}, using CPU")
                self.device = torch.device('cpu')
                self.to(self.device)
        else:
            # A parameter-less wrapped model would make next() raise
            # StopIteration; fall back to CPU in that case.
            first_param = next(self.parameters(), None)
            self.device = first_param.device if first_param is not None else torch.device('cpu')

        # Mixed precision
        if self.use_fp16:
            try:
                self.half()
            except Exception as e:
                logger.warning(f"Failed to set model to half precision: {e}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True
    ) -> Union[torch.Tensor, "RerankerOutput"]:
        """
        Forward pass for cross-encoder reranking.

        Args:
            input_ids: Token IDs [batch_size, seq_length]
            attention_mask: Attention mask [batch_size, seq_length]
            token_type_ids: Token type IDs (optional)
            labels: Accepted for interface compatibility only. No loss is
                returned from this method (use ``compute_loss``), so the labels
                are not forwarded to the wrapped model.
            return_dict: Return RerankerOutput object

        Returns:
            The score tensor when ``return_dict`` is False, otherwise a
            RerankerOutput with scores, logits and (base-model path only) the
            pooled embeddings.
        """
        if self.use_base_model:
            # Use base model + custom head
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True
            )
            # Get pooled output; fall back to the first ([CLS]) token when the
            # encoder has no pooler.
            pooled_output = getattr(outputs, 'pooler_output', None)
            if pooled_output is None:
                pooled_output = outputs.last_hidden_state[:, 0]
            # Apply dropout and classifier
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            embeddings = pooled_output
        else:
            # Use sequence classification model. Labels are deliberately not
            # passed: the loss the wrapped model would compute is never used
            # here and can fail on label dtype/shape mismatches.
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                return_dict=True
            )
            logits = outputs.logits
            embeddings = None

        # Convert logits to scores (sigmoid for binary classification)
        if self.num_labels == 1:
            scores = torch.sigmoid(logits.squeeze(-1))
        else:
            scores = F.softmax(logits, dim=-1)

        if not return_dict:
            return scores
        return RerankerOutput(
            scores=scores,
            logits=logits,
            embeddings=embeddings
        )

    def predict(
        self,
        queries: Union[str, List[str]],
        passages: Union[str, List[str], List[List[str]]],
        batch_size: int = 32,
        max_length: int = 512,
        show_progress_bar: bool = False,
        convert_to_numpy: bool = True,
        normalize_scores: bool = False
    ) -> Union[np.ndarray, torch.Tensor, List[float], float]:
        """
        Predict reranking scores for query-passage pairs.

        The module is switched to eval mode while scoring (so dropout is
        disabled) and its previous training mode is restored afterwards.

        Args:
            queries: Query or list of queries
            passages: A single passage, a flat list of passages (all scored
                against the first query), or one passage list per query
            batch_size: Batch size for inference
            max_length: Maximum sequence length
            show_progress_bar: Show progress bar
            convert_to_numpy: Convert to numpy array
            normalize_scores: Min-max normalize scores to [0, 1]

        Returns:
            Scores in pair order. With ``convert_to_numpy`` a numpy array, or a
            plain float when exactly one pair was scored for a single query;
            otherwise a tensor on the model device. An empty input yields an
            empty array/tensor.
        """
        # Handle input formats
        if isinstance(queries, str):
            queries = [queries]
        if isinstance(passages, str):
            passages = [[passages]]
        elif len(passages) == 0 or isinstance(passages[0], str):
            # A flat (possibly empty) list is the candidate set of one query.
            passages = [passages]  # type: ignore[assignment]

        # Ensure queries and passages match
        if len(queries) == 1 and len(passages) > 1:
            queries = queries * len(passages)
        elif len(queries) != len(passages):
            logger.warning(
                f"predict() received {len(queries)} queries but {len(passages)} passage lists; "
                "unmatched entries are ignored"
            )

        # Create pairs
        pairs = []
        for query, passage_list in zip(queries, passages):
            if isinstance(passage_list, str):
                passage_list = [passage_list]
            pairs.extend((query, passage) for passage in passage_list)

        # Nothing to score: torch.cat would fail on an empty list.
        if not pairs:
            empty_scores = torch.empty(0, device=self.device)
            return empty_scores.cpu().numpy() if convert_to_numpy else empty_scores

        # Reuse the tokenizer loaded with the model; only hit the hub/disk when
        # none is available (e.g. ModelScope source), and cache the result.
        tokenizer = getattr(self, 'tokenizer', None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
            self.tokenizer = tokenizer

        pbar = None
        if show_progress_bar:
            from tqdm import tqdm
            pbar = tqdm(total=len(pairs), desc="Scoring")

        # Score in batches
        batch_scores = []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for start in range(0, len(pairs), batch_size):
                    batch_pairs = pairs[start:start + batch_size]
                    # Tokenize
                    encoded = tokenizer(
                        [pair[0] for pair in batch_pairs],
                        [pair[1] for pair in batch_pairs],
                        padding=True,
                        truncation=True,
                        max_length=max_length,
                        return_tensors='pt'
                    )
                    # Move to device
                    encoded = {k: v.to(self.device) for k, v in encoded.items()}
                    # Get scores
                    batch_scores.append(self.forward(**encoded, return_dict=False))
                    if pbar is not None:
                        pbar.update(len(batch_pairs))
        finally:
            if pbar is not None:
                pbar.close()
            self.train(was_training)

        # Concatenate scores
        all_scores = torch.cat(batch_scores, dim=0)

        # Normalize if requested (in float32: the epsilon underflows in fp16)
        if normalize_scores:
            all_scores = all_scores.float()
            lowest, highest = all_scores.min(), all_scores.max()
            all_scores = (all_scores - lowest) / (highest - lowest + 1e-8)

        if not convert_to_numpy:
            return all_scores

        score_array = all_scores.cpu().numpy()
        # Collapse to a scalar only when there really is a single score;
        # one query with several passages must stay an array.
        if len(queries) == 1 and len(passages) == 1 and score_array.size == 1:
            return score_array.item()
        return score_array

    def rerank(
        self,
        query: str,
        passages: List[str],
        top_k: Optional[int] = None,
        return_scores: bool = False,
        batch_size: int = 32,
        max_length: int = 512
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        """
        Rerank passages for a query.

        Args:
            query: Query string
            passages: List of passages to rerank
            top_k: Return only top-k passages
            return_scores: Also return scores
            batch_size: Batch size for scoring
            max_length: Maximum sequence length

        Returns:
            Passages sorted by descending score, and the matching scores when
            ``return_scores`` is True. An empty ``passages`` list yields empty
            results.

        Raises:
            ValueError: If the model does not produce exactly one score per
                passage (e.g. ``num_labels`` > 1).
        """
        if not passages:
            return ([], []) if return_scores else []

        # predict() returns a bare float for a single pair; normalize to 1-D.
        scores = np.atleast_1d(
            self.predict(
                query,
                passages,
                batch_size=batch_size,
                max_length=max_length,
                convert_to_numpy=True
            )
        )
        if scores.ndim != 1:
            raise ValueError(
                f"rerank() needs one score per passage but got scores of shape {scores.shape}"
            )

        # Sort by scores (descending)
        sorted_indices = np.argsort(scores)[::-1]
        # Apply top_k if specified
        if top_k is not None:
            sorted_indices = sorted_indices[:top_k]

        # Get reranked passages
        reranked_passages = [passages[i] for i in sorted_indices]
        if return_scores:
            reranked_scores = [float(scores[i]) for i in sorted_indices]
            return reranked_passages, reranked_scores
        return reranked_passages

    def save_pretrained(self, save_path: str):
        """Save model to directory, including custom heads, tokenizer and config.

        Writes the wrapped model via its own ``save_pretrained``, a
        ``reranker_config.json`` with this wrapper's settings plus the
        ``[reranker]`` config section, and ``classifier.pt`` / ``dropout.pt``
        when the custom head exists.

        Args:
            save_path: Target directory (created if missing).
        """
        try:
            import os
            import json
            os.makedirs(save_path, exist_ok=True)
            # Save the base model
            self.model.save_pretrained(save_path)
            # Save additional configuration (including [reranker] section)
            from utils.config_loader import get_config
            config_dict = get_config().get_section('reranker')
            config = {
                'use_fp16': getattr(self, 'use_fp16', False),
                'num_labels': getattr(self, 'num_labels', 1),
                'reranker_config': config_dict
            }
            with open(os.path.join(save_path, "reranker_config.json"), 'w') as f:
                json.dump(config, f, indent=2)
            # Save custom heads if present
            if hasattr(self, 'classifier'):
                torch.save(self.classifier.state_dict(), os.path.join(save_path, 'classifier.pt'))
            if hasattr(self, 'dropout'):
                torch.save(self.dropout.state_dict(), os.path.join(save_path, 'dropout.pt'))
            # Save tokenizer if available. The constructor stores it as
            # `tokenizer`; `_tokenizer` is still honoured for older callers.
            tokenizer = getattr(self, 'tokenizer', None)
            if tokenizer is None:
                tokenizer = getattr(self, '_tokenizer', None)
            if tokenizer is not None:
                tokenizer.save_pretrained(save_path)  # type: ignore[attr-defined]
            logger.info(f"Reranker model and config saved to {save_path}")
        except Exception as e:
            logger.error(f"Failed to save reranker model: {e}")
            raise

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from directory, restoring custom heads and config.

        Args:
            model_path: Directory previously written by ``save_pretrained``.
            **kwargs: Extra constructor arguments. Explicit ``num_labels`` /
                ``use_fp16`` values override the ones stored in
                ``reranker_config.json``.

        Returns:
            The restored model instance.
        """
        import os
        import json
        try:
            # Load config
            config_path = os.path.join(model_path, "reranker_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    config = json.load(f)
            else:
                config = {}

            # Saved values are defaults only, so a caller-supplied keyword does
            # not collide with them.
            kwargs.setdefault('num_labels', config.get('num_labels', 1))
            kwargs.setdefault('use_fp16', config.get('use_fp16', False))

            # Instantiate model. The constructor already loads the weights from
            # model_path and applies device / fp16 / gradient checkpointing;
            # reloading the wrapped model afterwards would discard all of that.
            model = cls(model_name_or_path=model_path, **kwargs)

            # Load custom heads if present
            classifier_path = os.path.join(model_path, 'classifier.pt')
            if hasattr(model, 'classifier') and os.path.exists(classifier_path):
                model.classifier.load_state_dict(torch.load(classifier_path, map_location='cpu'))
            dropout_path = os.path.join(model_path, 'dropout.pt')
            if hasattr(model, 'dropout') and os.path.exists(dropout_path):
                model.dropout.load_state_dict(torch.load(dropout_path, map_location='cpu'))
            logger.info(f"Loaded BGE-Reranker model from {model_path}")
            return model
        except Exception as e:
            logger.error(f"Failed to load reranker model from {model_path}: {e}")
            raise

    def compute_loss(
        self,
        scores: torch.Tensor,
        labels: torch.Tensor,
        loss_type: str = "binary_cross_entropy"
    ) -> torch.Tensor:
        """
        Compute loss for training.

        The loss is evaluated in float32 so half-precision scores do not clash
        with float labels.

        Args:
            scores: Model outputs. Probabilities in [0, 1] for
                "binary_cross_entropy"; for "cross_entropy" these should be raw
                logits, because ``F.cross_entropy`` applies log-softmax itself.
            labels: Ground truth labels
            loss_type: One of "binary_cross_entropy", "cross_entropy", "mse"

        Returns:
            Loss value

        Raises:
            ValueError: If ``loss_type`` is not one of the supported names.
        """
        if loss_type == "binary_cross_entropy":
            return F.binary_cross_entropy(scores.float(), labels.float())
        if loss_type == "cross_entropy":
            return F.cross_entropy(scores.float(), labels.long())
        if loss_type == "mse":
            return F.mse_loss(scores.float(), labels.float())
        raise ValueError(f"Unknown loss type: {loss_type}")