# bge_finetune/training/reranker_trainer.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, List, Tuple, Any
import logging
import numpy as np
from torch.nn import functional as F
from tqdm import tqdm
import os

from .trainer import BaseTrainer
from models.bge_reranker import BGERerankerModel
from models.losses import ListwiseCrossEntropyLoss, PairwiseMarginLoss
from data.dataset import BGERerankerDataset
from config.reranker_config import BGERerankerConfig
from evaluation.metrics import compute_reranker_metrics
from utils.distributed import is_main_process

# AMP imports
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler

# Scheduler import
try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None

logger = logging.getLogger(__name__)


class BGERerankerTrainer(BaseTrainer):
    """
    Trainer for the BGE-Reranker cross-encoder model.

    Notes:
        - All dependencies (config, model, tokenizer, optimizer, scheduler, device, etc.)
          should be passed in from config-driven training scripts.
        - This class does not load configuration itself; pass config values from the
          main script so all components stay consistent.
    """

    def __init__(
        self,
        config: BGERerankerConfig,
        model: Optional[BGERerankerModel] = None,
        tokenizer=None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        # Initialize model if not provided
        if model is None:
            model = BGERerankerModel(
                model_name_or_path=config.model_name_or_path,
                num_labels=config.num_labels,
                cache_dir=config.cache_dir,
                use_fp16=config.fp16,
                gradient_checkpointing=config.gradient_checkpointing
            )

        # Initialize optimizer if not provided
        if optimizer is None:
            optimizer = self._create_optimizer(model, config)

        # Initialize parent class
        super().__init__(config, model, optimizer, scheduler, device)

        self.tokenizer = tokenizer
        self.config = config

        # Initialize losses based on config
        if config.loss_type == "listwise_ce":
            self.loss_fn = ListwiseCrossEntropyLoss(
                temperature=getattr(config, 'temperature', 1.0)
            )
        elif config.loss_type == "pairwise_margin":
            self.loss_fn = PairwiseMarginLoss(margin=config.margin)
        elif config.loss_type == "combined":
            self.listwise_loss = ListwiseCrossEntropyLoss(
                temperature=getattr(config, 'temperature', 1.0)
            )
            self.pairwise_loss = PairwiseMarginLoss(margin=config.margin)
        else:
            self.loss_fn = nn.BCEWithLogitsLoss()

        # Knowledge distillation setup
        self.use_kd = config.use_knowledge_distillation
        if self.use_kd:
            self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
            self.kd_temperature = config.distillation_temperature
            self.kd_loss_weight = config.distillation_alpha

        # Hard negative configuration
        self.hard_negative_ratio = config.hard_negative_ratio
        self.use_denoising = config.use_denoising
        self.denoise_threshold = getattr(config, 'denoise_confidence_threshold', 0.5)

    def _create_optimizer(self, model: nn.Module, config: BGERerankerConfig) -> torch.optim.Optimizer:
        """Create optimizer with proper parameter groups"""
        # Separate parameters: biases and LayerNorm weights are excluded from weight decay
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': config.weight_decay,
            },
            {
                'params': [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            }
        ]

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=config.learning_rate,
            eps=config.adam_epsilon,
            betas=(config.adam_beta1, config.adam_beta2)
        )
        return optimizer

    def create_scheduler(self, num_training_steps: int) -> Optional[object]:
        """Create learning rate scheduler"""
        if get_linear_schedule_with_warmup is None:
            logger.warning("transformers not available, using constant learning rate")
            return None

        num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
        return self.scheduler
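
    # Illustrative call (assumes the config exposes `gradient_accumulation_steps`;
    # adapt to the actual training script):
    #     steps_per_epoch = len(train_dataloader) // config.gradient_accumulation_steps
    #     trainer.create_scheduler(num_training_steps=steps_per_epoch * config.num_train_epochs)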

    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute reranker loss"""
        # Get model outputs
        outputs = self.model.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            return_dict=True
        )
        if isinstance(outputs, dict):
            logits = outputs['logits']
        else:
            logits = outputs.scores

        labels = batch['labels']

        # Handle padding labels
        if labels.dim() > 1:
            # Remove padding (-1 labels)
            valid_mask = labels != -1
            if valid_mask.any():
                logits = logits[valid_mask]
                labels = labels[valid_mask]
            else:
                # If all labels are padding, return zero loss
                return torch.tensor(0.0, device=logits.device, requires_grad=True)

        # Handle different loss types
        if self.config.loss_type == "listwise_ce":
            # Listwise loss expects grouped format
            try:
                groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
                if groups:
                    loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
                else:
                    loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            except Exception as e:
                # Fall back to BCE and log the specific error
                logger.warning(f"Listwise CE loss failed: {e}. Falling back to BCE.")
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
        elif self.config.loss_type == "pairwise_margin":
            # Extract positive and negative scores
            pos_mask = labels == 1
            neg_mask = labels == 0
            if pos_mask.any() and neg_mask.any():
                groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
                try:
                    loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
                except Exception as e:
                    # Fall back to BCE and log the specific error
                    logger.warning(f"Pairwise margin loss failed: {e}. Falling back to BCE.")
                    loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            else:
                # Fall back to BCE if there are no proper positive-negative pairs
                logger.warning("No positive-negative pairs found for pairwise loss. Using BCE.")
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
        elif self.config.loss_type == "combined":
            # Combine listwise and pairwise losses
            try:
                groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
                if groups:
                    listwise_loss = self.listwise_loss(logits.squeeze(-1), labels.float(), groups)
                    pairwise_loss = self.pairwise_loss(logits.squeeze(-1), labels.float(), groups)
                    loss = self.config.listwise_weight * listwise_loss + self.config.pairwise_weight * pairwise_loss
                else:
                    loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            except Exception as e:
                # Fall back to BCE and log the specific error
                logger.warning(f"Combined loss computation failed: {e}. Falling back to BCE.")
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
        else:
            # Default binary cross-entropy
            loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())

        # Add knowledge distillation loss if enabled
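        # The combined objective below mixes (1 - alpha) * task loss with alpha * KL divergence
        # between the temperature-softened teacher and student score distributions, where
        # alpha = distillation_alpha and T = distillation_temperature from the config. The T**2
        # factor compensates for the gradient scaling introduced by the temperature
        # (standard knowledge-distillation practice).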
        if self.use_kd and 'teacher_scores' in batch:
            teacher_scores = batch['teacher_scores']
            student_log_probs = F.log_softmax(logits / self.kd_temperature, dim=-1)
            teacher_probs = F.softmax(teacher_scores / self.kd_temperature, dim=-1)
            kd_loss = self.kd_loss_fn(student_log_probs, teacher_probs) * (self.kd_temperature ** 2)
            loss = (1 - self.kd_loss_weight) * loss + self.kd_loss_weight * kd_loss

        # Apply denoising if enabled: down-weight the batch loss by the fraction of
        # examples the teacher scores above the confidence threshold
        if self.use_denoising and 'teacher_scores' in batch:
            confidence_scores = torch.sigmoid(batch['teacher_scores'])
            confidence_mask = confidence_scores > self.denoise_threshold
            if confidence_mask.any():
                loss = loss * confidence_mask.float().mean()

        return loss

    def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Evaluate a single batch"""
        with torch.no_grad():
            # Get model outputs
            outputs = self.model.forward(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                return_dict=True
            )
            if isinstance(outputs, dict):
                logits = outputs['logits']
            else:
                logits = outputs.scores

            labels = batch['labels']

            # Handle padding labels
            if labels.dim() > 1:
                valid_mask = labels != -1
                if valid_mask.any():
                    logits = logits[valid_mask]
                    labels = labels[valid_mask]
                else:
                    return {'loss': 0.0, 'accuracy': 0.0}

            # Compute loss
            loss = self.compute_loss(batch)

            # Compute metrics based on group size
            if hasattr(self.config, 'train_group_size') and self.config.train_group_size > 1:
                # Listwise evaluation
                groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
                if groups:
                    try:
                        metrics = compute_reranker_metrics(
                            logits.squeeze(-1).cpu().numpy(),
                            labels.cpu().numpy(),
                            groups=groups,
                            k_values=[1, 3, 5, 10]
                        )
                        metrics['loss'] = loss.item()
                    except Exception as e:
                        logger.warning(f"Error computing reranker metrics: {e}")
                        metrics = {'loss': loss.item(), 'accuracy': 0.0}
                else:
                    metrics = {'loss': loss.item(), 'accuracy': 0.0}
            else:
                # Binary evaluation
                predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).float()
                accuracy = (predictions == labels.float()).float().mean()
                metrics = {
                    'accuracy': accuracy.item(),
                    'loss': loss.item()
                }

                # Add AUC if we have both positive and negative examples
                if len(torch.unique(labels)) > 1:
                    try:
                        from sklearn.metrics import roc_auc_score
                        auc = roc_auc_score(labels.cpu().numpy(), torch.sigmoid(logits.squeeze(-1)).cpu().numpy())
                        metrics['auc'] = float(auc)
                    except Exception as e:
                        logger.warning(f"Error computing AUC: {e}")

        return metrics

    def train_with_hard_negative_mining(
        self,
        train_dataset: BGERerankerDataset,
        eval_dataset: Optional[BGERerankerDataset] = None,
        retriever_model: Optional[nn.Module] = None
    ):
        """
        Train with periodic hard negative mining from the retriever.

        The first epoch uses the dataset's existing negatives; from the second epoch
        onward, hard negatives are re-mined with the retriever before each epoch.
        """
        if retriever_model is None:
            logger.warning("No retriever model provided for hard negative mining. Using standard training.")
            train_dataloader = self._create_dataloader(train_dataset, is_train=True)
            eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None

            # Only proceed if we have a valid training dataloader
            if train_dataloader is not None:
                return super().train(train_dataloader, eval_dataloader)
            else:
                logger.error("No valid training dataloader created")
                return

        # Training loop with hard negative updates
        for epoch in range(self.config.num_train_epochs):
            # Mine hard negatives at the beginning of each epoch (after the first)
            if epoch > 0:
                logger.info(f"Mining hard negatives for epoch {epoch + 1}")
                self._mine_hard_negatives(train_dataset, retriever_model)

            # Create data loaders with the updated negatives
            train_dataloader = self._create_dataloader(train_dataset, is_train=True)
            eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None

            # Only proceed if we have a valid training dataloader
            if train_dataloader is not None:
                # Train one epoch
                super().train(train_dataloader, eval_dataloader, num_epochs=1)
            else:
                logger.error(f"No valid training dataloader for epoch {epoch + 1}")
                continue

            # Update current epoch
            self.current_epoch = epoch

    def _create_dataloader(self, dataset, is_train: bool = True) -> Optional[DataLoader]:
        """Create data loader"""
        if dataset is None:
            return None

        from data.dataset import collate_reranker_batch

        batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size

        # Setup distributed sampler if needed
        sampler = None
        if self.is_distributed:
            from torch.utils.data.distributed import DistributedSampler
            sampler = DistributedSampler(dataset, shuffle=is_train)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train and sampler is None,
            sampler=sampler,
            num_workers=self.config.dataloader_num_workers,
            pin_memory=self.config.dataloader_pin_memory,
            drop_last=is_train,
            collate_fn=collate_reranker_batch
        )

    def _mine_hard_negatives(
        self,
        dataset: BGERerankerDataset,
        retriever_model: Any
    ):
        """Mine hard negatives using the retriever model and update the dataset's negatives."""
        logger.info("Starting hard negative mining using retriever model...")
        num_hard_negatives = getattr(self.config, 'num_hard_negatives', 10)
        retriever_batch_size = getattr(self.config, 'retriever_batch_size', 32)
        mined_count = 0

        # Collect all unique candidate passages from the dataset
        all_passages = set()
        for item in dataset.data:
            all_passages.update(item.get('pos', []))
            all_passages.update(item.get('neg', []))
        all_passages = list(all_passages)

        if not all_passages:
            logger.warning("No candidate passages found in dataset for hard negative mining.")
            return

        # Encode all candidate passages
        try:
            passage_embeddings = retriever_model.encode(
                all_passages,
                batch_size=retriever_batch_size,
                convert_to_numpy=True
            )
        except Exception as e:
            logger.error(f"Error encoding passages for hard negative mining: {e}")
            return
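
        # Similarity below is a plain dot product, which equals cosine similarity only if
        # the retriever's encode() returns L2-normalized embeddings (assumed here).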

        # For each query, mine hard negatives
        for idx in range(len(dataset)):
            item = dataset.data[idx]
            query = item['query']
            positives = set(item.get('pos', []))
            existing_negs = set(item.get('neg', []))
            try:
                query_emb = retriever_model.encode([query], batch_size=1, convert_to_numpy=True)[0]

                # Compute similarity
                scores = np.dot(passage_embeddings, query_emb)
                # Rank passages by similarity
                ranked_indices = np.argsort(scores)[::-1]

                new_negs = []
                for i in ranked_indices:
                    passage = all_passages[int(i)]  # Convert to int for indexing
                    if passage not in positives and passage not in existing_negs:
                        new_negs.append(passage)
                    if len(new_negs) >= num_hard_negatives:
                        break

                if new_negs:
                    item['neg'] = list(existing_negs) + new_negs
                    mined_count += 1
                    logger.debug(f"Query idx {idx}: Added {len(new_negs)} hard negatives.")
            except Exception as e:
                logger.error(f"Error mining hard negatives for query idx {idx}: {e}")

        logger.info(f"Hard negative mining complete. Updated negatives for {mined_count} queries.")

    def save_model(self, save_path: str):
        """Save the trained model"""
        try:
            os.makedirs(save_path, exist_ok=True)

            # Save model (unwrap DDP if necessary)
            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
            if hasattr(model_to_save, 'save_pretrained'):
                model_to_save.save_pretrained(save_path)  # type: ignore
            else:
                # Fallback to torch.save
                torch.save(model_to_save.state_dict(), os.path.join(save_path, 'pytorch_model.bin'))  # type: ignore

            # Save tokenizer if available
            if self.tokenizer and hasattr(self.tokenizer, 'save_pretrained'):
                self.tokenizer.save_pretrained(save_path)

            # Save config
            if hasattr(self.config, 'save'):
                self.config.save(os.path.join(save_path, 'training_config.json'))

            logger.info(f"Model saved to {save_path}")
        except Exception as e:
            logger.error(f"Error saving model: {e}")
            raise

    def export_for_inference(self, output_path: str):
        """Export model for efficient inference"""
        if is_main_process():
            logger.info(f"Exporting model for inference to {output_path}")

            # Save model in inference mode
            self.model.eval()

            # Save model state (unwrap DDP so the state dict keys carry no 'module.' prefix)
            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
            torch.save({
                'model_state_dict': model_to_save.state_dict(),
                'config': self.config.to_dict(),
                'model_type': 'bge_reranker'
            }, output_path)

            # Also save in HuggingFace format if possible
            try:
                hf_path = output_path.replace('.pt', '_hf')
                self.save_model(hf_path)
            except Exception as e:
                logger.warning(f"Failed to save in HuggingFace format: {e}")

            logger.info("Model exported successfully")