# (removed duplicated file-viewer metadata header: "420 lines / 16 KiB / Python")
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
from torch.optim import AdamW
|
|
from typing import Dict, Optional, List, Tuple
|
|
import logging
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
import os
|
|
|
|
from .trainer import BaseTrainer
|
|
from models.bge_m3 import BGEM3Model
|
|
from models.bge_reranker import BGERerankerModel
|
|
from models.losses import (
|
|
InfoNCELoss, ListwiseCrossEntropyLoss,
|
|
DynamicListwiseDistillationLoss, MultiTaskLoss
|
|
)
|
|
from data.dataset import JointDataset
|
|
from config.base_config import BaseConfig
|
|
from utils.distributed import is_main_process, get_world_size, reduce_dict
|
|
from utils.logging import get_logger
|
|
|
|
# Import autocast and GradScaler (NOTE: despite the old comment, no ImportError fallback is implemented here)
|
|
from torch.cuda.amp.autocast_mode import autocast
|
|
from torch.cuda.amp.grad_scaler import GradScaler
|
|
# Scheduler import
|
|
try:
|
|
from transformers import get_linear_schedule_with_warmup
|
|
except ImportError:
|
|
get_linear_schedule_with_warmup = None
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class JointTrainer(BaseTrainer):
    """
    Joint trainer for BGE-M3 and BGE-Reranker using RocketQAv2 approach

    Notes:
    - All parameters (config, retriever, reranker, optimizers, device, etc.) should be passed from config-driven training scripts.
    - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(
        self,
        config: BaseConfig,
        retriever: BGEM3Model,
        reranker: BGERerankerModel,
        retriever_optimizer: Optional[torch.optim.Optimizer] = None,
        reranker_optimizer: Optional[torch.optim.Optimizer] = None,
        device: Optional[torch.device] = None
    ):
        """
        Args:
            config: Training configuration (learning_rate, weight_decay, etc.).
            retriever: Bi-encoder model; also serves as BaseTrainer's "main" model.
            reranker: Cross-encoder model, moved to the trainer's device.
            retriever_optimizer: Optional optimizer for the retriever; an AdamW
                with ``config.learning_rate`` is created when omitted.
            reranker_optimizer: Optional optimizer for the reranker; an AdamW
                with twice the base learning rate is created when omitted.
            device: Optional target device, forwarded to BaseTrainer.
        """
        # Create a default retriever optimizer if none provided.
        if retriever_optimizer is None:
            retriever_optimizer = AdamW(
                retriever.parameters(),
                lr=config.learning_rate,
                weight_decay=config.weight_decay
            )

        # The retriever doubles as the base-class "model" so BaseTrainer handles
        # its device placement, scaler, scheduler, and checkpointing.
        super().__init__(config, retriever, retriever_optimizer, device=device)

        self.retriever = self.model  # Alias for clarity
        self.reranker = reranker
        self.reranker.to(self.device)

        # Setup reranker optimizer if not provided.
        if reranker_optimizer is None:
            self.reranker_optimizer = AdamW(
                self.reranker.parameters(),
                lr=config.learning_rate * 2,  # Higher LR for reranker
                weight_decay=config.weight_decay
            )
        else:
            self.reranker_optimizer = reranker_optimizer

        # Initialize losses.
        # getattr-with-default replaces the previous hasattr/ternary pattern,
        # matching the getattr style used for the joint-training knobs below.
        self.retriever_loss_fn = InfoNCELoss(
            temperature=getattr(config, 'temperature', 0.02),
            use_in_batch_negatives=True  # type: ignore[call-arg]
        )

        self.reranker_loss_fn = ListwiseCrossEntropyLoss(
            temperature=1.0,
            label_smoothing=0.0  # type: ignore[call-arg]
        )

        self.distillation_loss_fn = DynamicListwiseDistillationLoss(
            temperature=1.0,
            alpha=0.5  # type: ignore[call-arg]
        )

        # Joint training specific parameters (config-driven with defaults).
        self.update_retriever_interval = getattr(config, 'update_retriever_steps', 100)
        self.joint_loss_weight = getattr(config, 'joint_loss_weight', 0.3)
        self.use_dynamic_distillation = getattr(config, 'use_dynamic_listwise_distillation', True)

        # Metric history containers (not populated within this class's visible code).
        self.retriever_metrics: List[Dict] = []
        self.reranker_metrics: List[Dict] = []
|
    def compute_loss(self, batch: Tuple[Dict, Dict]) -> torch.Tensor:
        """
        Compute joint loss for both retriever and reranker

        Args:
            batch: Tuple of (retriever_batch, reranker_batch)

        Returns:
            Sum of retriever and reranker losses (for logging only); the
            per-model losses are also stashed on ``self.current_retriever_loss``
            / ``self.current_reranker_loss`` for separate backward passes.
        """
        retriever_batch, reranker_batch = batch

        # 1. Compute retriever loss
        retriever_loss = self._compute_retriever_loss(retriever_batch)

        # 2. Compute reranker loss
        reranker_loss = self._compute_reranker_loss(reranker_batch)

        # 3. Compute distillation losses if enabled
        if self.use_dynamic_distillation:
            # Get scores from both models.
            # NOTE(review): both score tensors are produced under no_grad (and the
            # helper methods detach internally as well), so the distillation terms
            # added below contribute constant values with ZERO gradient to either
            # model. If the distillation is meant to update the models (as in
            # RocketQAv2's dynamic listwise distillation), the scores must be
            # recomputed with gradients enabled — confirm the intended design.
            with torch.no_grad():
                retriever_scores = self._get_retriever_scores(retriever_batch)
                reranker_scores = self._get_reranker_scores(reranker_batch)

            # Compute mutual distillation
            retriever_distill_loss, reranker_distill_loss = self.distillation_loss_fn(
                retriever_scores,
                reranker_scores,
                labels=reranker_batch.get('labels')
            )

            # Add distillation to main losses, weighted by joint_loss_weight.
            retriever_loss = retriever_loss + self.joint_loss_weight * retriever_distill_loss
            reranker_loss = reranker_loss + self.joint_loss_weight * reranker_distill_loss

        # Combine losses (we'll handle separate backward passes)
        self.current_retriever_loss = retriever_loss
        self.current_reranker_loss = reranker_loss

        # Return combined loss for logging
        return retriever_loss + reranker_loss
|
def _compute_retriever_loss(self, batch: Dict) -> torch.Tensor:
|
|
"""Compute retriever (bi-encoder) loss"""
|
|
# Get embeddings
|
|
query_outputs = self.retriever(
|
|
input_ids=batch['query_input_ids'],
|
|
attention_mask=batch['query_attention_mask'],
|
|
return_dense=True
|
|
)
|
|
|
|
passage_outputs = self.retriever(
|
|
input_ids=batch['passage_input_ids'],
|
|
attention_mask=batch['passage_attention_mask'],
|
|
return_dense=True
|
|
)
|
|
|
|
# Compute contrastive loss
|
|
loss = self.retriever_loss_fn(
|
|
query_outputs['dense'],
|
|
passage_outputs['dense'],
|
|
labels=batch.get('labels')
|
|
)
|
|
|
|
return loss
|
|
|
|
def _compute_reranker_loss(self, batch: Dict) -> torch.Tensor:
|
|
"""Compute reranker (cross-encoder) loss"""
|
|
outputs = self.reranker(
|
|
input_ids=batch['input_ids'],
|
|
attention_mask=batch['attention_mask']
|
|
)
|
|
|
|
logits = outputs['logits']
|
|
labels = batch['labels']
|
|
|
|
# Reshape for listwise loss if needed
|
|
if logits.dim() == 1:
|
|
# Binary classification
|
|
loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
|
|
else:
|
|
# Listwise ranking
|
|
loss = self.reranker_loss_fn(logits, labels)
|
|
|
|
return loss
|
|
|
|
def _get_retriever_scores(self, batch: Dict) -> torch.Tensor:
|
|
"""Get retriever similarity scores"""
|
|
with torch.no_grad():
|
|
query_outputs = self.retriever(
|
|
input_ids=batch['query_input_ids'],
|
|
attention_mask=batch['query_attention_mask'],
|
|
return_dense=True
|
|
)
|
|
|
|
passage_outputs = self.retriever(
|
|
input_ids=batch['passage_input_ids'],
|
|
attention_mask=batch['passage_attention_mask'],
|
|
return_dense=True
|
|
)
|
|
|
|
# Compute similarities
|
|
scores = torch.matmul(
|
|
query_outputs['dense'],
|
|
passage_outputs['dense'].transpose(-1, -2)
|
|
)
|
|
|
|
return scores
|
|
|
|
def _get_reranker_scores(self, batch: Dict) -> torch.Tensor:
|
|
"""Get reranker scores"""
|
|
outputs = self.reranker(
|
|
input_ids=batch['input_ids'],
|
|
attention_mask=batch['attention_mask']
|
|
)
|
|
return outputs['logits']
|
|
|
|
    def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train both models for one epoch.

        Runs separate backward passes and optimizer steps for the retriever
        (``self.optimizer``) and the reranker (``self.reranker_optimizer``),
        sharing one GradScaler when AMP is active.

        Args:
            dataloader: Yields (retriever_batch, reranker_batch) tuples.
            epoch: Zero-based epoch index (used for the sampler and tqdm label).

        Returns:
            Dict of epoch-averaged 'retriever_loss', 'reranker_loss', 'total_loss'.
        """
        self.retriever.train()
        self.reranker.train()

        epoch_metrics = {
            'retriever_loss': 0.0,
            'reranker_loss': 0.0,
            'total_loss': 0.0
        }
        num_batches = 0

        # Set epoch for distributed sampler (reshuffles shards per epoch)
        if hasattr(dataloader, 'sampler') and hasattr(dataloader.sampler, 'set_epoch'):
            dataloader.sampler.set_epoch(epoch)  # type: ignore[attr-defined]

        progress_bar = tqdm(
            dataloader,
            desc=f"Joint Training Epoch {epoch + 1}",
            disable=not is_main_process()
        )

        for step, batch in enumerate(progress_bar):
            # Prepare batch (move tensors to device via BaseTrainer helper)
            retriever_batch, reranker_batch = batch
            retriever_batch = self._prepare_batch(retriever_batch)
            reranker_batch = self._prepare_batch(reranker_batch)
            batch = (retriever_batch, reranker_batch)

            # Forward pass.
            # NOTE(review): `total_loss` is never used afterwards — compute_loss's
            # real outputs are the stashed per-model losses read below.
            total_loss = self.compute_loss(batch)

            # Scale losses for gradient accumulation
            retriever_loss = self.current_retriever_loss / self.config.gradient_accumulation_steps
            reranker_loss = self.current_reranker_loss / self.config.gradient_accumulation_steps

            # Backward passes (separate for each model).
            # NOTE(review): retain_graph=True is only needed if the two losses
            # share graph nodes (e.g. gradient-carrying distillation) — confirm.
            if self.scaler:
                # Retriever backward
                self.scaler.scale(retriever_loss).backward(retain_graph=True)  # type: ignore[attr-defined]
                # Reranker backward
                self.scaler.scale(reranker_loss).backward()  # type: ignore[attr-defined]
            else:
                retriever_loss.backward(retain_graph=True)
                reranker_loss.backward()

            # Update weights only on accumulation boundaries
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradient clipping: unscale first so clipping sees true grads
                if self.scaler:
                    self.scaler.unscale_(self.optimizer)
                    self.scaler.unscale_(self.reranker_optimizer)

                torch.nn.utils.clip_grad_norm_(  # type: ignore[attr-defined]
                    self.retriever.parameters(),
                    self.config.max_grad_norm
                )
                torch.nn.utils.clip_grad_norm_(  # type: ignore[attr-defined]
                    self.reranker.parameters(),
                    self.config.max_grad_norm
                )

                # Optimizer steps (one shared scaler drives both optimizers)
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.step(self.reranker_optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                    self.reranker_optimizer.step()

                # Scheduler steps (retriever only)
                if self.scheduler:
                    self.scheduler.step()
                    # Note: Add reranker scheduler if needed

                # Zero gradients for the next accumulation window
                self.optimizer.zero_grad()
                self.reranker_optimizer.zero_grad()

                self.global_step += 1

                # Update hard negatives periodically
                if self.global_step % self.update_retriever_interval == 0:
                    self._update_hard_negatives()

                # Logging (un-scale the accumulated losses back for display)
                if self.global_step % self.config.logging_steps == 0:
                    step_metrics = {
                        'retriever_loss': retriever_loss.item() * self.config.gradient_accumulation_steps,
                        'reranker_loss': reranker_loss.item() * self.config.gradient_accumulation_steps,
                        'total_loss': (
                            retriever_loss.item() + reranker_loss.item()) * self.config.gradient_accumulation_steps,
                        'retriever_lr': self.optimizer.param_groups[0]['lr'],
                        'reranker_lr': self.reranker_optimizer.param_groups[0]['lr']
                    }

                    # Average the metrics across ranks before display/logging
                    if self.is_distributed:
                        step_metrics = reduce_dict(step_metrics)

                    progress_bar.set_postfix({
                        'ret_loss': f"{step_metrics['retriever_loss']:.4f}",
                        'rank_loss': f"{step_metrics['reranker_loss']:.4f}"
                    })

                    if is_main_process():
                        for key, value in step_metrics.items():
                            self.tb_logger.log_scalar(f"joint/{key}", value, self.global_step)
                        # Flush logs periodically
                        if self.global_step % (self.config.logging_steps * 10) == 0:
                            self.tb_logger.flush()

            # Accumulate metrics (still divided by accumulation steps here;
            # re-multiplied in the epoch averaging below)
            epoch_metrics['retriever_loss'] += retriever_loss.item()
            epoch_metrics['reranker_loss'] += reranker_loss.item()
            epoch_metrics['total_loss'] += (retriever_loss.item() + reranker_loss.item())
            num_batches += 1

        # Average metrics over batches and undo the accumulation scaling
        for key in epoch_metrics:
            epoch_metrics[key] = epoch_metrics[key] / num_batches * self.config.gradient_accumulation_steps

        if self.is_distributed:
            epoch_metrics = reduce_dict(epoch_metrics)

        return epoch_metrics
|
def evaluate_step(self, batch: Tuple[Dict, Dict]) -> Dict[str, float]:
|
|
"""Evaluate both models on a batch"""
|
|
retriever_batch, reranker_batch = batch
|
|
|
|
metrics = {}
|
|
|
|
# Evaluate retriever
|
|
with torch.no_grad():
|
|
retriever_loss = self._compute_retriever_loss(retriever_batch)
|
|
metrics['retriever_loss'] = retriever_loss.item()
|
|
|
|
# Compute retrieval metrics
|
|
retriever_scores = self._get_retriever_scores(retriever_batch)
|
|
labels = retriever_batch['labels']
|
|
|
|
# Simple accuracy: is the positive passage ranked first?
|
|
top_indices = retriever_scores.argmax(dim=-1)
|
|
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
|
|
metrics['retriever_accuracy'] = correct.item()
|
|
|
|
# Evaluate reranker
|
|
with torch.no_grad():
|
|
reranker_loss = self._compute_reranker_loss(reranker_batch)
|
|
metrics['reranker_loss'] = reranker_loss.item()
|
|
|
|
# Compute reranking metrics
|
|
reranker_scores = self._get_reranker_scores(reranker_batch)
|
|
labels = reranker_batch['labels']
|
|
|
|
if reranker_scores.dim() == 1:
|
|
# Binary accuracy
|
|
predictions = (reranker_scores > 0).float()
|
|
correct = (predictions == labels).float().mean()
|
|
else:
|
|
# Ranking accuracy
|
|
top_indices = reranker_scores.argmax(dim=-1)
|
|
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
|
|
|
|
metrics['reranker_accuracy'] = correct.item()
|
|
|
|
metrics['total_loss'] = metrics['retriever_loss'] + metrics['reranker_loss']
|
|
|
|
return metrics
|
|
|
|
def _update_hard_negatives(self):
|
|
"""Update hard negatives using current models (simplified version)"""
|
|
# This would implement the dynamic hard negative mining
|
|
# For now, just log that we would update
|
|
if is_main_process():
|
|
logger.info(f"Would update hard negatives at step {self.global_step}")
|
|
|
|
def save_models(self, output_dir: str):
|
|
"""Save both models"""
|
|
if not is_main_process():
|
|
return
|
|
|
|
# Save retriever
|
|
retriever_path = os.path.join(output_dir, "retriever")
|
|
os.makedirs(retriever_path, exist_ok=True)
|
|
|
|
if hasattr(self.retriever, 'module'):
|
|
self.retriever.module.save_pretrained(retriever_path) # type: ignore[attr-defined]
|
|
else:
|
|
self.retriever.save_pretrained(retriever_path) # type: ignore[attr-defined]
|
|
|
|
# Save reranker
|
|
reranker_path = os.path.join(output_dir, "reranker")
|
|
os.makedirs(reranker_path, exist_ok=True)
|
|
|
|
if hasattr(self.reranker, 'module'):
|
|
self.reranker.module.save_pretrained(reranker_path) # type: ignore[attr-defined]
|
|
else:
|
|
self.reranker.save_pretrained(reranker_path) # type: ignore[attr-defined]
|
|
|
|
logger.info(f"Saved models to {output_dir}")
|