# bge_finetune/training/joint_trainer.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, Tuple
from tqdm import tqdm
import os
from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import (
InfoNCELoss, ListwiseCrossEntropyLoss,
DynamicListwiseDistillationLoss, MultiTaskLoss
)
from config.base_config import BaseConfig
from utils.distributed import is_main_process, reduce_dict
from utils.logging import get_logger
# Import autocast and GradScaler, with a fallback for older torch layouts
try:
    from torch.cuda.amp import autocast, GradScaler
except ImportError:
    from torch.cuda.amp.autocast_mode import autocast
    from torch.cuda.amp.grad_scaler import GradScaler
# Scheduler import
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = get_logger(__name__)
class JointTrainer(BaseTrainer):
"""
Joint trainer for BGE-M3 and BGE-Reranker using RocketQAv2 approach
Notes:
- All parameters (config, retriever, reranker, optimizers, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
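
    Example (illustrative sketch; model construction and the outer training
    loop depend on your setup; the dataloader must yield
    (retriever_batch, reranker_batch) tuples):

        trainer = JointTrainer(
            config=config,        # a populated BaseConfig
            retriever=retriever,  # a BGEM3Model instance
            reranker=reranker,    # a BGERerankerModel instance
        )
        epoch_metrics = trainer._train_epoch(train_dataloader, epoch=0)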
"""
def __init__(
self,
config: BaseConfig,
retriever: BGEM3Model,
reranker: BGERerankerModel,
retriever_optimizer: Optional[torch.optim.Optimizer] = None,
reranker_optimizer: Optional[torch.optim.Optimizer] = None,
device: Optional[torch.device] = None
):
# Create default optimizer if none provided
if retriever_optimizer is None:
retriever_optimizer = AdamW(
retriever.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
# We'll use the retriever as the main model for the base class
super().__init__(config, retriever, retriever_optimizer, device=device)
self.retriever = self.model # Alias for clarity
self.reranker = reranker
self.reranker.to(self.device)
# Setup reranker optimizer if not provided
if reranker_optimizer is None:
self.reranker_optimizer = AdamW(
self.reranker.parameters(),
lr=config.learning_rate * 2, # Higher LR for reranker
weight_decay=config.weight_decay
)
else:
self.reranker_optimizer = reranker_optimizer
# Initialize losses
        self.retriever_loss_fn = InfoNCELoss(
            temperature=getattr(config, 'temperature', 0.02),
            use_in_batch_negatives=True  # type: ignore[call-arg]
        )
self.reranker_loss_fn = ListwiseCrossEntropyLoss(
temperature=1.0,
label_smoothing=0.0 # type: ignore[call-arg]
)
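        # RocketQAv2-style dynamic listwise distillation: both models score the
        # same candidate lists and each is regularized toward the other's score
        # distribution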
self.distillation_loss_fn = DynamicListwiseDistillationLoss(
temperature=1.0,
alpha=0.5 # type: ignore[call-arg]
)
# Joint training specific parameters
self.update_retriever_interval = getattr(config, 'update_retriever_steps', 100)
self.joint_loss_weight = getattr(config, 'joint_loss_weight', 0.3)
self.use_dynamic_distillation = getattr(config, 'use_dynamic_listwise_distillation', True)
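        # update_retriever_interval: steps between hard-negative refreshes;
        # joint_loss_weight: weight on the mutual-distillation terms added to
        # each model's loss in compute_loss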
# Tracking metrics
self.retriever_metrics = []
self.reranker_metrics = []
def compute_loss(self, batch: Tuple[Dict, Dict]) -> torch.Tensor:
"""
Compute joint loss for both retriever and reranker
Args:
batch: Tuple of (retriever_batch, reranker_batch)
"""
retriever_batch, reranker_batch = batch
# 1. Compute retriever loss
retriever_loss = self._compute_retriever_loss(retriever_batch)
# 2. Compute reranker loss
reranker_loss = self._compute_reranker_loss(reranker_batch)
# 3. Compute distillation losses if enabled
        if self.use_dynamic_distillation:
            # Score the candidates with both models. Gradients are kept so the
            # distillation terms can actually update the models; the loss is
            # expected to detach the teacher-side scores internally.
            retriever_scores = self._get_retriever_scores(retriever_batch)
            reranker_scores = self._get_reranker_scores(reranker_batch)
            # Compute mutual distillation
            retriever_distill_loss, reranker_distill_loss = self.distillation_loss_fn(
                retriever_scores,
                reranker_scores,
                labels=reranker_batch.get('labels')
            )
# Add distillation to main losses
retriever_loss = retriever_loss + self.joint_loss_weight * retriever_distill_loss
reranker_loss = reranker_loss + self.joint_loss_weight * reranker_distill_loss
# Combine losses (we'll handle separate backward passes)
self.current_retriever_loss = retriever_loss
self.current_reranker_loss = reranker_loss
# Return combined loss for logging
return retriever_loss + reranker_loss
def _compute_retriever_loss(self, batch: Dict) -> torch.Tensor:
"""Compute retriever (bi-encoder) loss"""
# Get embeddings
query_outputs = self.retriever(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True
)
passage_outputs = self.retriever(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True
)
# Compute contrastive loss
loss = self.retriever_loss_fn(
query_outputs['dense'],
passage_outputs['dense'],
labels=batch.get('labels')
)
return loss
def _compute_reranker_loss(self, batch: Dict) -> torch.Tensor:
"""Compute reranker (cross-encoder) loss"""
outputs = self.reranker(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
logits = outputs['logits']
labels = batch['labels']
# Reshape for listwise loss if needed
if logits.dim() == 1:
# Binary classification
loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
else:
# Listwise ranking
loss = self.reranker_loss_fn(logits, labels)
return loss
    def _get_retriever_scores(self, batch: Dict) -> torch.Tensor:
        """Get retriever query-passage similarity scores.

        No torch.no_grad() here: training-time distillation needs gradients,
        and evaluation callers already wrap this call in torch.no_grad().
        """
        query_outputs = self.retriever(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )
        passage_outputs = self.retriever(
            input_ids=batch['passage_input_ids'],
            attention_mask=batch['passage_attention_mask'],
            return_dense=True
        )
        # Dense dot-product similarities between all queries and passages
        scores = torch.matmul(
            query_outputs['dense'],
            passage_outputs['dense'].transpose(-1, -2)
        )
        return scores
def _get_reranker_scores(self, batch: Dict) -> torch.Tensor:
"""Get reranker scores"""
outputs = self.reranker(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
return outputs['logits']
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
"""Train both models for one epoch"""
self.retriever.train()
self.reranker.train()
epoch_metrics = {
'retriever_loss': 0.0,
'reranker_loss': 0.0,
'total_loss': 0.0
}
num_batches = 0
# Set epoch for distributed sampler
if hasattr(dataloader, 'sampler') and hasattr(dataloader.sampler, 'set_epoch'):
dataloader.sampler.set_epoch(epoch) # type: ignore[attr-defined]
progress_bar = tqdm(
dataloader,
desc=f"Joint Training Epoch {epoch + 1}",
disable=not is_main_process()
)
for step, batch in enumerate(progress_bar):
# Prepare batch
retriever_batch, reranker_batch = batch
retriever_batch = self._prepare_batch(retriever_batch)
reranker_batch = self._prepare_batch(reranker_batch)
batch = (retriever_batch, reranker_batch)
            # Forward pass (computes both losses and stores them on self)
            self.compute_loss(batch)
# Scale losses for gradient accumulation
retriever_loss = self.current_retriever_loss / self.config.gradient_accumulation_steps
reranker_loss = self.current_reranker_loss / self.config.gradient_accumulation_steps
            # Backward passes (separate for each model); retain_graph=True on
            # the first backward because the distillation terms can couple the
            # two computation graphs
if self.scaler:
# Retriever backward
self.scaler.scale(retriever_loss).backward(retain_graph=True) # type: ignore[attr-defined]
# Reranker backward
self.scaler.scale(reranker_loss).backward() # type: ignore[attr-defined]
else:
retriever_loss.backward(retain_graph=True)
reranker_loss.backward()
# Update weights
if (step + 1) % self.config.gradient_accumulation_steps == 0:
# Gradient clipping
if self.scaler:
self.scaler.unscale_(self.optimizer)
self.scaler.unscale_(self.reranker_optimizer)
torch.nn.utils.clip_grad_norm_( # type: ignore[attr-defined]
self.retriever.parameters(),
self.config.max_grad_norm
)
torch.nn.utils.clip_grad_norm_( # type: ignore[attr-defined]
self.reranker.parameters(),
self.config.max_grad_norm
)
# Optimizer steps
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.step(self.reranker_optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.reranker_optimizer.step()
# Scheduler steps
if self.scheduler:
self.scheduler.step()
# Note: Add reranker scheduler if needed
# Zero gradients
self.optimizer.zero_grad()
self.reranker_optimizer.zero_grad()
self.global_step += 1
# Update hard negatives periodically
if self.global_step % self.update_retriever_interval == 0:
self._update_hard_negatives()
# Logging
if self.global_step % self.config.logging_steps == 0:
step_metrics = {
'retriever_loss': retriever_loss.item() * self.config.gradient_accumulation_steps,
'reranker_loss': reranker_loss.item() * self.config.gradient_accumulation_steps,
                    'total_loss': (
                        (retriever_loss.item() + reranker_loss.item())
                        * self.config.gradient_accumulation_steps
                    ),
'retriever_lr': self.optimizer.param_groups[0]['lr'],
'reranker_lr': self.reranker_optimizer.param_groups[0]['lr']
}
if self.is_distributed:
step_metrics = reduce_dict(step_metrics)
progress_bar.set_postfix({
'ret_loss': f"{step_metrics['retriever_loss']:.4f}",
'rank_loss': f"{step_metrics['reranker_loss']:.4f}"
})
if is_main_process():
for key, value in step_metrics.items():
self.tb_logger.log_scalar(f"joint/{key}", value, self.global_step)
# Flush logs periodically
if self.global_step % (self.config.logging_steps * 10) == 0:
self.tb_logger.flush()
# Accumulate metrics
epoch_metrics['retriever_loss'] += retriever_loss.item()
epoch_metrics['reranker_loss'] += reranker_loss.item()
epoch_metrics['total_loss'] += (retriever_loss.item() + reranker_loss.item())
num_batches += 1
        # Average metrics over batches (undoing the gradient-accumulation
        # scaling applied to the per-step losses above)
for key in epoch_metrics:
epoch_metrics[key] = epoch_metrics[key] / num_batches * self.config.gradient_accumulation_steps
if self.is_distributed:
epoch_metrics = reduce_dict(epoch_metrics)
return epoch_metrics
def evaluate_step(self, batch: Tuple[Dict, Dict]) -> Dict[str, float]:
"""Evaluate both models on a batch"""
retriever_batch, reranker_batch = batch
metrics = {}
# Evaluate retriever
with torch.no_grad():
retriever_loss = self._compute_retriever_loss(retriever_batch)
metrics['retriever_loss'] = retriever_loss.item()
# Compute retrieval metrics
retriever_scores = self._get_retriever_scores(retriever_batch)
labels = retriever_batch['labels']
            # Top-1 accuracy: does the highest-scored passage match the labeled
            # positive? (labels are assumed to be one-hot per query)
top_indices = retriever_scores.argmax(dim=-1)
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
metrics['retriever_accuracy'] = correct.item()
# Evaluate reranker
with torch.no_grad():
reranker_loss = self._compute_reranker_loss(reranker_batch)
metrics['reranker_loss'] = reranker_loss.item()
# Compute reranking metrics
reranker_scores = self._get_reranker_scores(reranker_batch)
labels = reranker_batch['labels']
if reranker_scores.dim() == 1:
# Binary accuracy
predictions = (reranker_scores > 0).float()
correct = (predictions == labels).float().mean()
else:
# Ranking accuracy
top_indices = reranker_scores.argmax(dim=-1)
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
metrics['reranker_accuracy'] = correct.item()
metrics['total_loss'] = metrics['retriever_loss'] + metrics['reranker_loss']
return metrics
def _update_hard_negatives(self):
"""Update hard negatives using current models (simplified version)"""
# This would implement the dynamic hard negative mining
# For now, just log that we would update
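        # A full implementation (sketch) would roughly:
        #   1. embed the corpus with the current retriever,
        #   2. retrieve top-k candidates for each training query,
        #   3. drop labeled positives (optionally denoising with reranker
        #      scores, as in RocketQA),
        #   4. write the survivors back into the dataset as hard negatives.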
if is_main_process():
logger.info(f"Would update hard negatives at step {self.global_step}")
def save_models(self, output_dir: str):
"""Save both models"""
if not is_main_process():
return
# Save retriever
retriever_path = os.path.join(output_dir, "retriever")
os.makedirs(retriever_path, exist_ok=True)
        # Unwrap DistributedDataParallel if necessary
        if hasattr(self.retriever, 'module'):
self.retriever.module.save_pretrained(retriever_path) # type: ignore[attr-defined]
else:
self.retriever.save_pretrained(retriever_path) # type: ignore[attr-defined]
# Save reranker
reranker_path = os.path.join(output_dir, "reranker")
os.makedirs(reranker_path, exist_ok=True)
if hasattr(self.reranker, 'module'):
self.reranker.module.save_pretrained(reranker_path) # type: ignore[attr-defined]
else:
self.reranker.save_pretrained(reranker_path) # type: ignore[attr-defined]
logger.info(f"Saved models to {output_dir}")