import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, List, Tuple, Any
import logging
import numpy as np
from torch.nn import functional as F
from tqdm import tqdm
import os

from .trainer import BaseTrainer
from models.bge_reranker import BGERerankerModel
from models.losses import ListwiseCrossEntropyLoss, PairwiseMarginLoss
from data.dataset import BGERerankerDataset
from config.reranker_config import BGERerankerConfig
from evaluation.metrics import compute_reranker_metrics
from utils.distributed import is_main_process

# AMP imports
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler

# Scheduler import
try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None

logger = logging.getLogger(__name__)


class BGERerankerTrainer(BaseTrainer):
    """
    Trainer for the BGE-Reranker cross-encoder model.

    Notes:
    - All components (config, model, tokenizer, optimizer, scheduler, device, etc.)
      should be passed in from config-driven training scripts.
    - This class does not load configuration itself; pass config values from your
      main script so that all trainers stay consistent.
    """

    def __init__(
        self,
        config: BGERerankerConfig,
        model: Optional[BGERerankerModel] = None,
        tokenizer=None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        # Initialize the model if not provided
        if model is None:
            model = BGERerankerModel(
                model_name_or_path=config.model_name_or_path,
                num_labels=config.num_labels,
                cache_dir=config.cache_dir,
                use_fp16=config.fp16,
                gradient_checkpointing=config.gradient_checkpointing
            )

        # Initialize the optimizer if not provided
        if optimizer is None:
            optimizer = self._create_optimizer(model, config)

        # Initialize parent class
        super().__init__(config, model, optimizer, scheduler, device)

        self.tokenizer = tokenizer
        self.config = config

        # Initialize the loss function(s) based on config
        if config.loss_type == "listwise_ce":
            self.loss_fn = ListwiseCrossEntropyLoss(
                temperature=getattr(config, 'temperature', 1.0)
            )
        elif config.loss_type == "pairwise_margin":
            self.loss_fn = PairwiseMarginLoss(margin=config.margin)
        elif config.loss_type == "combined":
            self.listwise_loss = ListwiseCrossEntropyLoss(
                temperature=getattr(config, 'temperature', 1.0)
            )
            self.pairwise_loss = PairwiseMarginLoss(margin=config.margin)
        else:
            self.loss_fn = nn.BCEWithLogitsLoss()

        # Knowledge distillation setup
        self.use_kd = config.use_knowledge_distillation
        if self.use_kd:
            self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
            self.kd_temperature = config.distillation_temperature
            self.kd_loss_weight = config.distillation_alpha

        # Hard negative configuration
        self.hard_negative_ratio = config.hard_negative_ratio
        self.use_denoising = config.use_denoising
        self.denoise_threshold = getattr(config, 'denoise_confidence_threshold', 0.5)
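
    # Loss selection sketch (illustrative only): config.loss_type chooses the training
    # objective wired up above, i.e.
    #
    #     "listwise_ce"     -> ListwiseCrossEntropyLoss(temperature=config.temperature)
    #     "pairwise_margin" -> PairwiseMarginLoss(margin=config.margin)
    #     "combined"        -> listwise_weight * listwise + pairwise_weight * pairwise
    #     anything else     -> nn.BCEWithLogitsLoss()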

    def _create_optimizer(self, model: nn.Module, config: BGERerankerConfig) -> torch.optim.Optimizer:
        """Create an AdamW optimizer with weight-decay parameter groups."""
        # Parameters whose names match these patterns are excluded from weight decay
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': config.weight_decay,
            },
            {
                'params': [p for n, p in model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            }
        ]

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=config.learning_rate,
            eps=config.adam_epsilon,
            betas=(config.adam_beta1, config.adam_beta2)
        )

        return optimizer

    def create_scheduler(self, num_training_steps: int) -> Optional[object]:
        """Create a linear warmup/decay learning rate scheduler."""
        if get_linear_schedule_with_warmup is None:
            logger.warning("transformers not available, using a constant learning rate")
            return None

        # Prefer an explicit warmup_steps value; otherwise derive it from warmup_ratio
        num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
        return self.scheduler
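
    # Scheduler setup sketch (illustrative only; gradient_accumulation_steps is an
    # assumed config field and may be named differently in BGERerankerConfig):
    #
    #     steps_per_epoch = len(train_dataloader) // getattr(config, "gradient_accumulation_steps", 1)
    #     total_steps = steps_per_epoch * config.num_train_epochs
    #     trainer.create_scheduler(num_training_steps=total_steps)
    #
    # When warmup_steps is unset (falsy), the warmup length falls back to
    # int(total_steps * config.warmup_ratio).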

    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute the reranker training loss for a batch."""
        # Get model outputs
        outputs = self.model.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            return_dict=True
        )

        if isinstance(outputs, dict):
            logits = outputs['logits']
        else:
            logits = outputs.scores

        labels = batch['labels']

        # Handle padding labels
        if labels.dim() > 1:
            # Remove padding (-1) labels
            valid_mask = labels != -1
            if valid_mask.any():
                logits = logits[valid_mask]
                labels = labels[valid_mask]
            else:
                # If all labels are padding, return a zero loss
                return torch.tensor(0.0, device=logits.device, requires_grad=True)

        # Handle the different loss types
        if self.config.loss_type == "listwise_ce":
            # Listwise loss expects a grouped format (train_group_size candidates per query)
            try:
                groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
                if groups:
                    loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
                else:
                    loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            except Exception as e:
                # Fall back to BCE and log the specific error
                logger.warning(f"Listwise CE loss failed: {e}. Falling back to BCE.")
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())

        elif self.config.loss_type == "pairwise_margin":
            # Pairwise loss needs at least one positive and one negative in the batch
            pos_mask = labels == 1
            neg_mask = labels == 0

            if pos_mask.any() and neg_mask.any():
                groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
                try:
                    loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
                except Exception as e:
                    # Fall back to BCE and log the specific error
                    logger.warning(f"Pairwise margin loss failed: {e}. Falling back to BCE.")
                    loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            else:
                # Fall back to BCE if there are no proper positive/negative pairs
                logger.warning("No positive-negative pairs found for pairwise loss. Using BCE.")
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())

        elif self.config.loss_type == "combined":
            # Weighted combination of listwise and pairwise losses
            try:
                groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
                if groups:
                    listwise_loss = self.listwise_loss(logits.squeeze(-1), labels.float(), groups)
                    pairwise_loss = self.pairwise_loss(logits.squeeze(-1), labels.float(), groups)
                    loss = self.config.listwise_weight * listwise_loss + self.config.pairwise_weight * pairwise_loss
                else:
                    loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            except Exception as e:
                # Fall back to BCE and log the specific error
                logger.warning(f"Combined loss computation failed: {e}. Falling back to BCE.")
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())

        else:
            # Default: binary cross-entropy
            loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())

        # Add knowledge distillation loss if enabled
        if self.use_kd and 'teacher_scores' in batch:
            teacher_scores = batch['teacher_scores']
            student_probs = F.log_softmax(logits / self.kd_temperature, dim=-1)
            teacher_probs = F.softmax(teacher_scores / self.kd_temperature, dim=-1)

            kd_loss = self.kd_loss_fn(student_probs, teacher_probs) * (self.kd_temperature ** 2)
            loss = (1 - self.kd_loss_weight) * loss + self.kd_loss_weight * kd_loss

        # Apply denoising if enabled: down-weight the loss by the fraction of
        # candidates whose teacher confidence exceeds the threshold
        if self.use_denoising and 'teacher_scores' in batch:
            confidence_scores = torch.sigmoid(batch['teacher_scores'])
            confidence_mask = confidence_scores > self.denoise_threshold
            if confidence_mask.any():
                loss = loss * confidence_mask.float().mean()

        return loss
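
    # Grouped-batch sketch (illustrative only): with train_group_size = 4, the collated
    # batch is expected to hold 4 candidates per query laid out contiguously, e.g.
    # labels = [1, 0, 0, 0,  1, 0, 0, 0] for two queries, giving groups = [4, 4]. The
    # integer division above only builds entries for full groups, so a trailing partial
    # group is not represented in `groups`.
    #
    # Distillation sketch: with distillation_alpha = 0.3 and distillation_temperature = T = 2.0,
    #     loss = 0.7 * task_loss + 0.3 * T**2 * KL(teacher_T || student_T)
    # where the temperature-scaled softmax over teacher_scores provides the soft targets.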

    def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Evaluate a single batch and return a dict of metrics."""
        with torch.no_grad():
            # Get model outputs
            outputs = self.model.forward(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                return_dict=True
            )

            if isinstance(outputs, dict):
                logits = outputs['logits']
            else:
                logits = outputs.scores

            labels = batch['labels']

            # Handle padding labels
            if labels.dim() > 1:
                valid_mask = labels != -1
                if valid_mask.any():
                    logits = logits[valid_mask]
                    labels = labels[valid_mask]
                else:
                    return {'loss': 0.0, 'accuracy': 0.0}

            # Compute loss
            loss = self.compute_loss(batch)

            # Compute metrics based on group size
            if hasattr(self.config, 'train_group_size') and self.config.train_group_size > 1:
                # Listwise (ranking) evaluation
                groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
                if groups:
                    try:
                        metrics = compute_reranker_metrics(
                            logits.squeeze(-1).cpu().numpy(),
                            labels.cpu().numpy(),
                            groups=groups,
                            k_values=[1, 3, 5, 10]
                        )
                        metrics['loss'] = loss.item()
                    except Exception as e:
                        logger.warning(f"Error computing reranker metrics: {e}")
                        metrics = {'loss': loss.item(), 'accuracy': 0.0}
                else:
                    metrics = {'loss': loss.item(), 'accuracy': 0.0}
            else:
                # Binary (pointwise) evaluation
                predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).float()
                accuracy = (predictions == labels.float()).float().mean()

                metrics = {
                    'accuracy': accuracy.item(),
                    'loss': loss.item()
                }

                # Add AUC if we have both positive and negative examples
                if len(torch.unique(labels)) > 1:
                    try:
                        from sklearn.metrics import roc_auc_score
                        auc = roc_auc_score(labels.cpu().numpy(), torch.sigmoid(logits.squeeze(-1)).cpu().numpy())
                        metrics['auc'] = float(auc)
                    except Exception as e:
                        logger.warning(f"Error computing AUC: {e}")

            return metrics
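
    # Evaluation sketch (illustrative only): with train_group_size > 1 the metrics come
    # from compute_reranker_metrics() at k in {1, 3, 5, 10} (exact key names depend on that
    # function); with train_group_size == 1 the method returns pointwise metrics, e.g.
    #
    #     {'accuracy': 0.83, 'loss': 0.41, 'auc': 0.91}
    #
    # where 'auc' is only present when the batch contains both label classes.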

    def train_with_hard_negative_mining(
        self,
        train_dataset: BGERerankerDataset,
        eval_dataset: Optional[BGERerankerDataset] = None,
        retriever_model: Optional[nn.Module] = None
    ):
        """
        Train with periodic hard negative mining from a retriever model.
        """
        if retriever_model is None:
            logger.warning("No retriever model provided for hard negative mining. Using standard training.")
            train_dataloader = self._create_dataloader(train_dataset, is_train=True)
            eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None

            # Only proceed if we have a valid training dataloader
            if train_dataloader is not None:
                return super().train(train_dataloader, eval_dataloader)
            else:
                logger.error("No valid training dataloader created")
                return

        # Training loop with hard negative updates
        for epoch in range(self.config.num_train_epochs):
            # Mine hard negatives at the beginning of every epoch after the first
            if epoch > 0:
                logger.info(f"Mining hard negatives for epoch {epoch + 1}")
                self._mine_hard_negatives(train_dataset, retriever_model)

            # Create data loaders with the updated negatives
            train_dataloader = self._create_dataloader(train_dataset, is_train=True)
            eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None

            # Only proceed if we have a valid training dataloader
            if train_dataloader is not None:
                # Train for one epoch
                super().train(train_dataloader, eval_dataloader, num_epochs=1)
            else:
                logger.error(f"No valid training dataloader for epoch {epoch + 1}")
                continue

            # Update the current epoch
            self.current_epoch = epoch
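
    # Hard-negative-mining usage sketch (illustrative only; assumes the retriever exposes
    # an encode(texts, batch_size=..., convert_to_numpy=True) method, as relied on in
    # _mine_hard_negatives below):
    #
    #     retriever = SomeBiEncoderRetriever(...)   # hypothetical retriever class
    #     trainer.train_with_hard_negative_mining(
    #         train_dataset=train_dataset,
    #         eval_dataset=eval_dataset,
    #         retriever_model=retriever,
    #     )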

    def _create_dataloader(self, dataset, is_train: bool = True) -> Optional[DataLoader]:
        """Create a data loader for training or evaluation."""
        if dataset is None:
            return None

        from data.dataset import collate_reranker_batch

        batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size

        # Set up a distributed sampler if needed
        sampler = None
        if self.is_distributed:
            from torch.utils.data.distributed import DistributedSampler
            sampler = DistributedSampler(dataset, shuffle=is_train)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train and sampler is None,
            sampler=sampler,
            num_workers=self.config.dataloader_num_workers,
            pin_memory=self.config.dataloader_pin_memory,
            drop_last=is_train,
            collate_fn=collate_reranker_batch
        )
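
    # Batching note (sketch): the grouped losses expect the collate function to keep each
    # query's train_group_size candidates contiguous, so per-device batch sizes should be
    # a multiple of train_group_size; otherwise the group list built in compute_loss() will
    # not cover the trailing partial group. For example, train_group_size = 8 pairs
    # naturally with per_device_train_batch_size in {8, 16, 32}.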

    def _mine_hard_negatives(
        self,
        dataset: BGERerankerDataset,
        retriever_model: Any
    ):
        """Mine hard negatives with the retriever model and update the dataset's negatives."""
        logger.info("Starting hard negative mining using retriever model...")
        num_hard_negatives = getattr(self.config, 'num_hard_negatives', 10)
        retriever_batch_size = getattr(self.config, 'retriever_batch_size', 32)
        mined_count = 0

        # Collect all unique candidate passages from the dataset
        all_passages = set()
        for item in dataset.data:
            all_passages.update(item.get('pos', []))
            all_passages.update(item.get('neg', []))
        all_passages = list(all_passages)

        if not all_passages:
            logger.warning("No candidate passages found in dataset for hard negative mining.")
            return

        # Encode all candidate passages
        try:
            passage_embeddings = retriever_model.encode(
                all_passages,
                batch_size=retriever_batch_size,
                convert_to_numpy=True
            )
        except Exception as e:
            logger.error(f"Error encoding passages for hard negative mining: {e}")
            return

        # Mine hard negatives for each query
        for idx in range(len(dataset)):
            item = dataset.data[idx]
            query = item['query']
            positives = set(item.get('pos', []))
            existing_negs = set(item.get('neg', []))

            try:
                query_emb = retriever_model.encode([query], batch_size=1, convert_to_numpy=True)[0]

                # Compute dot-product similarity between the query and all passages
                scores = np.dot(passage_embeddings, query_emb)

                # Rank passages by similarity (highest first)
                ranked_indices = np.argsort(scores)[::-1]

                # Keep the top-ranked passages that are neither positives nor existing negatives
                new_negs = []
                for i in ranked_indices:
                    passage = all_passages[int(i)]  # Convert to int for indexing
                    if passage not in positives and passage not in existing_negs:
                        new_negs.append(passage)
                        if len(new_negs) >= num_hard_negatives:
                            break

                if new_negs:
                    item['neg'] = list(existing_negs) + new_negs
                    mined_count += 1
                    logger.debug(f"Query idx {idx}: Added {len(new_negs)} hard negatives.")

            except Exception as e:
                logger.error(f"Error mining hard negatives for query idx {idx}: {e}")

        logger.info(f"Hard negative mining complete. Updated negatives for {mined_count} queries.")
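
    # Mining sketch (illustrative only): for a query embedding q and passage matrix P
    # (both produced by retriever_model.encode), the ranking above reduces to
    #
    #     scores = P @ q                      # dot-product similarity; cosine if vectors are L2-normalised
    #     ranked = np.argsort(scores)[::-1]   # most similar passages first
    #
    # Passages already listed as positives or existing negatives are skipped, so the new
    # "hard" negatives are the highest-scoring remaining candidates.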

    def save_model(self, save_path: str):
        """Save the trained model, tokenizer, and training config."""
        try:
            os.makedirs(save_path, exist_ok=True)

            # Unwrap DDP/DataParallel if necessary and save the model
            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
            if hasattr(model_to_save, 'save_pretrained'):
                model_to_save.save_pretrained(save_path)  # type: ignore
            else:
                # Fallback to a plain state-dict checkpoint
                torch.save(model_to_save.state_dict(), os.path.join(save_path, 'pytorch_model.bin'))  # type: ignore

            # Save the tokenizer if available
            if self.tokenizer and hasattr(self.tokenizer, 'save_pretrained'):
                self.tokenizer.save_pretrained(save_path)

            # Save the training config
            if hasattr(self.config, 'save'):
                self.config.save(os.path.join(save_path, 'training_config.json'))

            logger.info(f"Model saved to {save_path}")

        except Exception as e:
            logger.error(f"Error saving model: {e}")
            raise
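
    # Reload sketch (illustrative only; the exact BGERerankerModel loading API is an
    # assumption, not confirmed by this file):
    #
    #     model = BGERerankerModel(model_name_or_path=save_path, num_labels=config.num_labels)
    #     # or, for the state-dict fallback:
    #     state = torch.load(os.path.join(save_path, 'pytorch_model.bin'), map_location='cpu')
    #     model.load_state_dict(state)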

    def export_for_inference(self, output_path: str):
        """Export the model for efficient inference."""
        if is_main_process():
            logger.info(f"Exporting model for inference to {output_path}")

            # Put the model in inference mode
            self.model.eval()

            # Save the model state together with its config
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'config': self.config.to_dict(),
                'model_type': 'bge_reranker'
            }, output_path)

            # Also save in HuggingFace format if possible
            try:
                hf_path = output_path.replace('.pt', '_hf')
                self.save_model(hf_path)
            except Exception as e:
                logger.warning(f"Failed to save in HuggingFace format: {e}")

            logger.info("Model exported successfully")
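

# End-to-end usage sketch (illustrative only; the BGERerankerConfig constructor arguments,
# tokenizer setup, and dataset construction below are assumptions, not the actual API):
#
#     config = BGERerankerConfig(model_name_or_path="BAAI/bge-reranker-base",
#                                loss_type="listwise_ce")
#     trainer = BGERerankerTrainer(config=config, tokenizer=tokenizer)
#     train_dl = trainer._create_dataloader(train_dataset, is_train=True)
#     eval_dl = trainer._create_dataloader(eval_dataset, is_train=False)
#     trainer.create_scheduler(num_training_steps=len(train_dl) * config.num_train_epochs)
#     trainer.train(train_dl, eval_dl)            # inherited from BaseTrainer
#     trainer.save_model("outputs/bge-reranker")
#     trainer.export_for_inference("outputs/bge-reranker.pt")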