init commit
This commit is contained in:
483
models/bge_reranker.py
Normal file
483
models/bge_reranker.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
BGE-Reranker cross-encoder model wrapper
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
|
||||
from typing import Dict, List, Optional, Union, Tuple
|
||||
import numpy as np
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from utils.config_loader import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RerankerOutput:
|
||||
"""Output class for reranker predictions"""
|
||||
scores: torch.Tensor
|
||||
logits: torch.Tensor
|
||||
embeddings: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def load_reranker_model_and_tokenizer(model_name_or_path, cache_dir=None, num_labels=1):
|
||||
"""Load reranker model and tokenizer from HuggingFace or ModelScope based on config.toml"""
|
||||
config = get_config()
|
||||
|
||||
# Try to get model source from config, with fallback to huggingface
|
||||
try:
|
||||
model_source = config.get('model_paths', 'source', 'huggingface')
|
||||
except Exception:
|
||||
# Fallback if config loading fails
|
||||
logger.warning("Failed to load model source from config, defaulting to 'huggingface'")
|
||||
model_source = 'huggingface'
|
||||
|
||||
logger.info(f"Loading reranker model from source: {model_source}")
|
||||
|
||||
if model_source == 'huggingface':
|
||||
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
try:
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=model_config,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
use_base_model = False
|
||||
except:
|
||||
model = AutoModel.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=model_config,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
use_base_model = True
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir
|
||||
)
|
||||
return model, tokenizer, model_config, use_base_model
|
||||
elif model_source == 'modelscope':
|
||||
try:
|
||||
from modelscope.models import Model as MSModel
|
||||
from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
|
||||
ms_snapshot_download(model_id=model_name_or_path, cache_dir=cache_dir)
|
||||
model = MSModel.from_pretrained(model_name_or_path, cache_dir=cache_dir)
|
||||
tokenizer = None # ModelScope tokenizer loading can be added if needed
|
||||
model_config = None
|
||||
use_base_model = False
|
||||
return model, tokenizer, model_config, use_base_model
|
||||
except ImportError:
|
||||
logger.error("modelscope not available. Install it first: pip install modelscope")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load reranker model from ModelScope: {e}")
|
||||
raise
|
||||
else:
|
||||
logger.error(f"Unknown model source: {model_source}. Supported: 'huggingface', 'modelscope'.")
|
||||
raise RuntimeError(f"Unknown model source: {model_source}")
|
||||
|
||||
|
||||
class BGERerankerModel(nn.Module):
|
||||
"""
|
||||
BGE-Reranker cross-encoder model for passage reranking
|
||||
|
||||
Notes:
|
||||
- All parameters (model_name_or_path, num_labels, cache_dir, use_fp16, gradient_checkpointing, device, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str = "BAAI/bge-reranker-base",
|
||||
num_labels: int = 1,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_fp16: bool = False,
|
||||
gradient_checkpointing: bool = False,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize BGE-Reranker model.
|
||||
Args:
|
||||
model_name_or_path: Path or identifier for pretrained model
|
||||
num_labels: Number of output labels
|
||||
cache_dir: Directory for model cache
|
||||
use_fp16: Use mixed precision
|
||||
gradient_checkpointing: Enable gradient checkpointing
|
||||
device: Device to use
|
||||
"""
|
||||
super().__init__()
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.num_labels = num_labels
|
||||
self.use_fp16 = use_fp16
|
||||
# Load configuration and model
|
||||
try:
|
||||
logger.info(f"Loading BGE-Reranker model from {model_name_or_path} using source from config.toml...")
|
||||
self.model, self.tokenizer, self.config, self.use_base_model = load_reranker_model_and_tokenizer(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
num_labels=num_labels
|
||||
)
|
||||
logger.info("BGE-Reranker model loaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load BGE-Reranker model: {e}")
|
||||
raise RuntimeError(f"Model initialization failed: {e}")
|
||||
# Add classification head
|
||||
if self.use_base_model:
|
||||
self.classifier = nn.Linear(self.config.hidden_size, num_labels) # type: ignore[attr-defined]
|
||||
self.dropout = nn.Dropout(
|
||||
self.config.hidden_dropout_prob if hasattr(self.config, 'hidden_dropout_prob') else 0.1) # type: ignore[attr-defined]
|
||||
# Enable gradient checkpointing
|
||||
if gradient_checkpointing:
|
||||
try:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
except Exception:
|
||||
logger.warning("Gradient checkpointing requested but not supported by this model.")
|
||||
# Set device
|
||||
if device:
|
||||
try:
|
||||
self.device = torch.device(device)
|
||||
self.to(self.device)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to move model to {device}: {e}, using CPU")
|
||||
self.device = torch.device('cpu')
|
||||
self.to(self.device)
|
||||
else:
|
||||
self.device = next(self.parameters()).device
|
||||
# Mixed precision
|
||||
if self.use_fp16:
|
||||
try:
|
||||
self.half()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set model to half precision: {e}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True
|
||||
) -> Union[torch.Tensor, RerankerOutput]:
|
||||
"""
|
||||
Forward pass for cross-encoder reranking
|
||||
|
||||
Args:
|
||||
input_ids: Token IDs [batch_size, seq_length]
|
||||
attention_mask: Attention mask [batch_size, seq_length]
|
||||
token_type_ids: Token type IDs (optional)
|
||||
labels: Labels for training (optional)
|
||||
return_dict: Return RerankerOutput object
|
||||
|
||||
Returns:
|
||||
Scores or RerankerOutput
|
||||
"""
|
||||
if self.use_base_model:
|
||||
# Use base model + custom head
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Get pooled output (CLS token)
|
||||
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
|
||||
pooled_output = outputs.pooler_output
|
||||
else:
|
||||
# Manual pooling if pooler not available
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
pooled_output = last_hidden_state[:, 0]
|
||||
|
||||
# Apply dropout and classifier
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
embeddings = pooled_output
|
||||
else:
|
||||
# Use sequence classification model
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=labels,
|
||||
return_dict=True
|
||||
)
|
||||
logits = outputs.logits
|
||||
embeddings = None
|
||||
|
||||
# Convert logits to scores (sigmoid for binary classification)
|
||||
if self.num_labels == 1:
|
||||
scores = torch.sigmoid(logits.squeeze(-1))
|
||||
else:
|
||||
scores = F.softmax(logits, dim=-1)
|
||||
|
||||
if not return_dict:
|
||||
return scores
|
||||
|
||||
return RerankerOutput(
|
||||
scores=scores,
|
||||
logits=logits,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
queries: Union[str, List[str]],
|
||||
passages: Union[str, List[str], List[List[str]]],
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512,
|
||||
show_progress_bar: bool = False,
|
||||
convert_to_numpy: bool = True,
|
||||
normalize_scores: bool = False
|
||||
) -> Union[np.ndarray, torch.Tensor, List[float]]:
|
||||
"""
|
||||
Predict reranking scores for query-passage pairs
|
||||
|
||||
Args:
|
||||
queries: Query or list of queries
|
||||
passages: Passage(s) or list of passages per query
|
||||
batch_size: Batch size for inference
|
||||
max_length: Maximum sequence length
|
||||
show_progress_bar: Show progress bar
|
||||
convert_to_numpy: Convert to numpy array
|
||||
normalize_scores: Normalize scores to [0, 1]
|
||||
|
||||
Returns:
|
||||
Reranking scores
|
||||
"""
|
||||
# Handle input formats
|
||||
if isinstance(queries, str):
|
||||
queries = [queries]
|
||||
|
||||
if isinstance(passages, str):
|
||||
passages = [[passages]]
|
||||
elif isinstance(passages[0], str):
|
||||
passages = [passages] # type: ignore[assignment]
|
||||
|
||||
# Ensure queries and passages match
|
||||
if len(queries) == 1 and len(passages) > 1:
|
||||
queries = queries * len(passages)
|
||||
|
||||
# Create pairs
|
||||
pairs = []
|
||||
for query, passage_list in zip(queries, passages):
|
||||
if isinstance(passage_list, str):
|
||||
passage_list = [passage_list]
|
||||
for passage in passage_list:
|
||||
pairs.append((query, passage))
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
|
||||
|
||||
# Score in batches
|
||||
all_scores = []
|
||||
|
||||
if show_progress_bar:
|
||||
from tqdm import tqdm
|
||||
pbar = tqdm(total=len(pairs), desc="Scoring")
|
||||
|
||||
for i in range(0, len(pairs), batch_size):
|
||||
batch_pairs = pairs[i:i + batch_size]
|
||||
|
||||
# Tokenize
|
||||
batch_queries = [pair[0] for pair in batch_pairs]
|
||||
batch_passages = [pair[1] for pair in batch_pairs]
|
||||
|
||||
encoded = tokenizer(
|
||||
batch_queries,
|
||||
batch_passages,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
# Move to device
|
||||
encoded = {k: v.to(self.device) for k, v in encoded.items()}
|
||||
|
||||
# Get scores
|
||||
with torch.no_grad():
|
||||
outputs = self.forward(**encoded, return_dict=True)
|
||||
scores = outputs.scores # type: ignore[attr-defined]
|
||||
|
||||
all_scores.append(scores)
|
||||
|
||||
if show_progress_bar:
|
||||
pbar.update(len(batch_pairs))
|
||||
|
||||
if show_progress_bar:
|
||||
pbar.close()
|
||||
|
||||
# Concatenate scores
|
||||
all_scores = torch.cat(all_scores, dim=0)
|
||||
|
||||
# Normalize if requested
|
||||
if normalize_scores:
|
||||
all_scores = (all_scores - all_scores.min()) / (all_scores.max() - all_scores.min() + 1e-8)
|
||||
|
||||
# Convert format
|
||||
if convert_to_numpy:
|
||||
all_scores = all_scores.cpu().numpy()
|
||||
elif not isinstance(all_scores, torch.Tensor):
|
||||
all_scores = all_scores.cpu()
|
||||
|
||||
# Reshape if single query
|
||||
if len(queries) == 1 and len(passages) == 1:
|
||||
return all_scores.item() if convert_to_numpy else all_scores # type: ignore[attr-defined]
|
||||
|
||||
return all_scores
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
passages: List[str],
|
||||
top_k: Optional[int] = None,
|
||||
return_scores: bool = False,
|
||||
batch_size: int = 32,
|
||||
max_length: int = 512
|
||||
) -> Union[List[str], Tuple[List[str], List[float]]]:
|
||||
"""
|
||||
Rerank passages for a query
|
||||
|
||||
Args:
|
||||
query: Query string
|
||||
passages: List of passages to rerank
|
||||
top_k: Return only top-k passages
|
||||
return_scores: Also return scores
|
||||
batch_size: Batch size for scoring
|
||||
max_length: Maximum sequence length
|
||||
|
||||
Returns:
|
||||
Reranked passages (and optionally scores)
|
||||
"""
|
||||
# Get scores
|
||||
scores = self.predict(
|
||||
query,
|
||||
passages,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
convert_to_numpy=True
|
||||
)
|
||||
|
||||
# Sort by scores (descending)
|
||||
sorted_indices = np.argsort(scores)[::-1]
|
||||
|
||||
# Apply top_k if specified
|
||||
if top_k is not None:
|
||||
sorted_indices = sorted_indices[:top_k]
|
||||
|
||||
# Get reranked passages
|
||||
reranked_passages = [passages[i] for i in sorted_indices]
|
||||
|
||||
if return_scores:
|
||||
reranked_scores = [float(scores[i]) for i in sorted_indices]
|
||||
return reranked_passages, reranked_scores
|
||||
|
||||
return reranked_passages
|
||||
|
||||
def save_pretrained(self, save_path: str):
|
||||
"""Save model to directory, including custom heads and config."""
|
||||
try:
|
||||
import os
|
||||
import json
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save the base model
|
||||
self.model.save_pretrained(save_path)
|
||||
|
||||
# Save additional configuration (including [reranker] section)
|
||||
from utils.config_loader import get_config
|
||||
config_dict = get_config().get_section('reranker')
|
||||
config = {
|
||||
'use_fp16': self.use_fp16 if hasattr(self, 'use_fp16') else False,
|
||||
'num_labels': getattr(self, 'num_labels', 1),
|
||||
'reranker_config': config_dict
|
||||
}
|
||||
with open(f"{save_path}/reranker_config.json", 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save custom heads if present
|
||||
if hasattr(self, 'classifier'):
|
||||
torch.save(self.classifier.state_dict(), os.path.join(save_path, 'classifier.pt'))
|
||||
if hasattr(self, 'dropout'):
|
||||
torch.save(self.dropout.state_dict(), os.path.join(save_path, 'dropout.pt'))
|
||||
|
||||
# Save tokenizer if available
|
||||
if hasattr(self, '_tokenizer') and self._tokenizer is not None:
|
||||
self._tokenizer.save_pretrained(save_path) # type: ignore[attr-defined]
|
||||
|
||||
logger.info(f"Reranker model and config saved to {save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save reranker model: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, **kwargs):
|
||||
"""Load model from directory, restoring custom heads and config."""
|
||||
import os
|
||||
import json
|
||||
try:
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "reranker_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
# Instantiate model
|
||||
model = cls(
|
||||
model_name_or_path=model_path,
|
||||
num_labels=config.get('num_labels', 1),
|
||||
use_fp16=config.get('use_fp16', False),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Load base model weights
|
||||
model.model = model.model.from_pretrained(model_path)
|
||||
|
||||
# Load custom heads if present
|
||||
classifier_path = os.path.join(model_path, 'classifier.pt')
|
||||
if hasattr(model, 'classifier') and os.path.exists(classifier_path):
|
||||
model.classifier.load_state_dict(torch.load(classifier_path, map_location='cpu'))
|
||||
dropout_path = os.path.join(model_path, 'dropout.pt')
|
||||
if hasattr(model, 'dropout') and os.path.exists(dropout_path):
|
||||
model.dropout.load_state_dict(torch.load(dropout_path, map_location='cpu'))
|
||||
|
||||
logger.info(f"Loaded BGE-Reranker model from {model_path}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load reranker model from {model_path}: {e}")
|
||||
raise
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
scores: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
loss_type: str = "binary_cross_entropy"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute loss for training
|
||||
|
||||
Args:
|
||||
scores: Model scores
|
||||
labels: Ground truth labels
|
||||
loss_type: Type of loss function
|
||||
|
||||
Returns:
|
||||
Loss value
|
||||
"""
|
||||
if loss_type == "binary_cross_entropy":
|
||||
return F.binary_cross_entropy(scores, labels.float())
|
||||
elif loss_type == "cross_entropy":
|
||||
return F.cross_entropy(scores, labels.long())
|
||||
elif loss_type == "mse":
|
||||
return F.mse_loss(scores, labels.float())
|
||||
else:
|
||||
raise ValueError(f"Unknown loss type: {loss_type}")
|
||||
Reference in New Issue
Block a user