init commit

training/joint_trainer.py (new file, 419 lines)

@@ -0,0 +1,419 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, List, Tuple
import logging
from tqdm import tqdm
import numpy as np
import os

from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import (
    InfoNCELoss, ListwiseCrossEntropyLoss,
    DynamicListwiseDistillationLoss, MultiTaskLoss
)
from data.dataset import JointDataset
from config.base_config import BaseConfig
from utils.distributed import is_main_process, get_world_size, reduce_dict
from utils.logging import get_logger

# AMP utilities (autocast and GradScaler)
from torch.cuda.amp import autocast, GradScaler
# Scheduler import (transformers is an optional dependency)
try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None

logger = get_logger(__name__)


class JointTrainer(BaseTrainer):
    """
    Joint trainer for BGE-M3 and BGE-Reranker using the RocketQAv2 approach.

    Notes:
        - All parameters (config, retriever, reranker, optimizers, device, etc.)
          should be passed in from config-driven training scripts.
        - This class does not load config directly; pass config values from your
          main script for consistency.
    """

    def __init__(
        self,
        config: BaseConfig,
        retriever: BGEM3Model,
        reranker: BGERerankerModel,
        retriever_optimizer: Optional[torch.optim.Optimizer] = None,
        reranker_optimizer: Optional[torch.optim.Optimizer] = None,
        device: Optional[torch.device] = None
    ):
        # Create default optimizer if none provided
        if retriever_optimizer is None:
            retriever_optimizer = AdamW(
                retriever.parameters(),
                lr=config.learning_rate,
                weight_decay=config.weight_decay
            )

        # We'll use the retriever as the main model for the base class
        super().__init__(config, retriever, retriever_optimizer, device=device)

        self.retriever = self.model  # Alias for clarity
        self.reranker = reranker
        self.reranker.to(self.device)

        # Set up reranker optimizer if not provided
        if reranker_optimizer is None:
            self.reranker_optimizer = AdamW(
                self.reranker.parameters(),
                lr=config.learning_rate * 2,  # Higher LR for reranker
                weight_decay=config.weight_decay
            )
        else:
            self.reranker_optimizer = reranker_optimizer

        # Initialize losses
        self.retriever_loss_fn = InfoNCELoss(
            temperature=getattr(config, 'temperature', 0.02),
            use_in_batch_negatives=True  # type: ignore[call-arg]
        )

        self.reranker_loss_fn = ListwiseCrossEntropyLoss(
            temperature=1.0,
            label_smoothing=0.0  # type: ignore[call-arg]
        )

        self.distillation_loss_fn = DynamicListwiseDistillationLoss(
            temperature=1.0,
            alpha=0.5  # type: ignore[call-arg]
        )

        # Joint training specific parameters
        self.update_retriever_interval = getattr(config, 'update_retriever_steps', 100)
        self.joint_loss_weight = getattr(config, 'joint_loss_weight', 0.3)
        self.use_dynamic_distillation = getattr(config, 'use_dynamic_listwise_distillation', True)

        # Tracking metrics
        self.retriever_metrics = []
        self.reranker_metrics = []

    def compute_loss(self, batch: Tuple[Dict, Dict]) -> torch.Tensor:
        """
        Compute the joint loss for both retriever and reranker.

        Args:
            batch: Tuple of (retriever_batch, reranker_batch). The retriever batch
                must provide query/passage input_ids and attention_mask tensors
                (labels optional); the reranker batch must provide input_ids,
                attention_mask and labels.

        Returns:
            Combined scalar loss (for logging). The per-model losses are stored in
            self.current_retriever_loss and self.current_reranker_loss for the
            separate backward passes in _train_epoch.
        """
        retriever_batch, reranker_batch = batch

        # 1. Compute retriever loss
        retriever_loss = self._compute_retriever_loss(retriever_batch)

        # 2. Compute reranker loss
        reranker_loss = self._compute_reranker_loss(reranker_batch)

        # 3. Compute distillation losses if enabled
        if self.use_dynamic_distillation:
            # Score both models with gradients enabled so the distillation terms
            # can backpropagate; the distillation loss is expected to detach the
            # teacher-side distribution internally.
            retriever_scores = self._get_retriever_scores(retriever_batch)
            reranker_scores = self._get_reranker_scores(reranker_batch)

            # Compute mutual distillation
            retriever_distill_loss, reranker_distill_loss = self.distillation_loss_fn(
                retriever_scores,
                reranker_scores,
                labels=reranker_batch.get('labels')
            )

            # Add distillation to main losses
            retriever_loss = retriever_loss + self.joint_loss_weight * retriever_distill_loss
            reranker_loss = reranker_loss + self.joint_loss_weight * reranker_distill_loss

        # Stash losses (we'll handle separate backward passes)
        self.current_retriever_loss = retriever_loss
        self.current_reranker_loss = reranker_loss

        # Return combined loss for logging
        return retriever_loss + reranker_loss

    def _compute_retriever_loss(self, batch: Dict) -> torch.Tensor:
        """Compute retriever (bi-encoder) loss"""
        # Get embeddings
        query_outputs = self.retriever(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )

        passage_outputs = self.retriever(
            input_ids=batch['passage_input_ids'],
            attention_mask=batch['passage_attention_mask'],
            return_dense=True
        )

        # Compute contrastive loss
        loss = self.retriever_loss_fn(
            query_outputs['dense'],
            passage_outputs['dense'],
            labels=batch.get('labels')
        )

        return loss

    def _compute_reranker_loss(self, batch: Dict) -> torch.Tensor:
        """Compute reranker (cross-encoder) loss"""
        outputs = self.reranker(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )

        logits = outputs['logits']
        labels = batch['labels']

        # Choose pointwise or listwise loss depending on logit shape
        if logits.dim() == 1:
            # Binary classification
            loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
        else:
            # Listwise ranking
            loss = self.reranker_loss_fn(logits, labels)

        return loss

    def _get_retriever_scores(self, batch: Dict) -> torch.Tensor:
        """Get retriever similarity scores (wrap in torch.no_grad() for inference-only use)"""
        query_outputs = self.retriever(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )

        passage_outputs = self.retriever(
            input_ids=batch['passage_input_ids'],
            attention_mask=batch['passage_attention_mask'],
            return_dense=True
        )

        # Compute query-passage similarities; gradients are kept so these scores
        # can also feed the distillation loss in compute_loss
        scores = torch.matmul(
            query_outputs['dense'],
            passage_outputs['dense'].transpose(-1, -2)
        )

        return scores

    def _get_reranker_scores(self, batch: Dict) -> torch.Tensor:
        """Get reranker scores"""
        outputs = self.reranker(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        return outputs['logits']

    def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train both models for one epoch"""
        self.retriever.train()
        self.reranker.train()

        epoch_metrics = {
            'retriever_loss': 0.0,
            'reranker_loss': 0.0,
            'total_loss': 0.0
        }
        num_batches = 0

        # Set epoch for distributed sampler
        if hasattr(dataloader, 'sampler') and hasattr(dataloader.sampler, 'set_epoch'):
            dataloader.sampler.set_epoch(epoch)  # type: ignore[attr-defined]

        progress_bar = tqdm(
            dataloader,
            desc=f"Joint Training Epoch {epoch + 1}",
            disable=not is_main_process()
        )

        for step, batch in enumerate(progress_bar):
            # Prepare batch
            retriever_batch, reranker_batch = batch
            retriever_batch = self._prepare_batch(retriever_batch)
            reranker_batch = self._prepare_batch(reranker_batch)
            batch = (retriever_batch, reranker_batch)

            # Forward pass
            total_loss = self.compute_loss(batch)

            # Scale losses for gradient accumulation
            retriever_loss = self.current_retriever_loss / self.config.gradient_accumulation_steps
            reranker_loss = self.current_reranker_loss / self.config.gradient_accumulation_steps

            # Backward passes (separate for each model)
            if self.scaler:
                # Retriever backward
                self.scaler.scale(retriever_loss).backward(retain_graph=True)  # type: ignore[attr-defined]
                # Reranker backward
                self.scaler.scale(reranker_loss).backward()  # type: ignore[attr-defined]
            else:
                retriever_loss.backward(retain_graph=True)
                reranker_loss.backward()

            # Update weights
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradient clipping (unscale first when AMP is active)
                if self.scaler:
                    self.scaler.unscale_(self.optimizer)
                    self.scaler.unscale_(self.reranker_optimizer)

                torch.nn.utils.clip_grad_norm_(  # type: ignore[attr-defined]
                    self.retriever.parameters(),
                    self.config.max_grad_norm
                )
                torch.nn.utils.clip_grad_norm_(  # type: ignore[attr-defined]
                    self.reranker.parameters(),
                    self.config.max_grad_norm
                )

                # Optimizer steps
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.step(self.reranker_optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                    self.reranker_optimizer.step()

                # Scheduler steps
                if self.scheduler:
                    self.scheduler.step()
                    # Note: Add a reranker scheduler if needed

                # Zero gradients
                self.optimizer.zero_grad()
                self.reranker_optimizer.zero_grad()

                self.global_step += 1

                # Update hard negatives periodically
                if self.global_step % self.update_retriever_interval == 0:
                    self._update_hard_negatives()

                # Logging
                if self.global_step % self.config.logging_steps == 0:
                    step_metrics = {
                        'retriever_loss': retriever_loss.item() * self.config.gradient_accumulation_steps,
                        'reranker_loss': reranker_loss.item() * self.config.gradient_accumulation_steps,
                        'total_loss': (
                            retriever_loss.item() + reranker_loss.item()
                        ) * self.config.gradient_accumulation_steps,
                        'retriever_lr': self.optimizer.param_groups[0]['lr'],
                        'reranker_lr': self.reranker_optimizer.param_groups[0]['lr']
                    }

                    if self.is_distributed:
                        step_metrics = reduce_dict(step_metrics)

                    progress_bar.set_postfix({
                        'ret_loss': f"{step_metrics['retriever_loss']:.4f}",
                        'rank_loss': f"{step_metrics['reranker_loss']:.4f}"
                    })

                    if is_main_process():
                        for key, value in step_metrics.items():
                            self.tb_logger.log_scalar(f"joint/{key}", value, self.global_step)
                        # Flush logs periodically
                        if self.global_step % (self.config.logging_steps * 10) == 0:
                            self.tb_logger.flush()

            # Accumulate metrics
            epoch_metrics['retriever_loss'] += retriever_loss.item()
            epoch_metrics['reranker_loss'] += reranker_loss.item()
            epoch_metrics['total_loss'] += (retriever_loss.item() + reranker_loss.item())
            num_batches += 1

        # Average metrics (and undo the gradient-accumulation scaling applied above)
        for key in epoch_metrics:
            epoch_metrics[key] = epoch_metrics[key] / num_batches * self.config.gradient_accumulation_steps

        if self.is_distributed:
            epoch_metrics = reduce_dict(epoch_metrics)

        return epoch_metrics

    def evaluate_step(self, batch: Tuple[Dict, Dict]) -> Dict[str, float]:
        """Evaluate both models on a batch"""
        retriever_batch, reranker_batch = batch

        metrics = {}

        # Evaluate retriever
        with torch.no_grad():
            retriever_loss = self._compute_retriever_loss(retriever_batch)
            metrics['retriever_loss'] = retriever_loss.item()

            # Compute retrieval metrics
            retriever_scores = self._get_retriever_scores(retriever_batch)
            labels = retriever_batch['labels']

            # Simple accuracy: is the positive passage ranked first?
            top_indices = retriever_scores.argmax(dim=-1)
            correct = (top_indices == labels.argmax(dim=-1)).float().mean()
            metrics['retriever_accuracy'] = correct.item()

        # Evaluate reranker
        with torch.no_grad():
            reranker_loss = self._compute_reranker_loss(reranker_batch)
            metrics['reranker_loss'] = reranker_loss.item()

            # Compute reranking metrics
            reranker_scores = self._get_reranker_scores(reranker_batch)
            labels = reranker_batch['labels']

            if reranker_scores.dim() == 1:
                # Binary accuracy
                predictions = (reranker_scores > 0).float()
                correct = (predictions == labels).float().mean()
            else:
                # Ranking accuracy
                top_indices = reranker_scores.argmax(dim=-1)
                correct = (top_indices == labels.argmax(dim=-1)).float().mean()

            metrics['reranker_accuracy'] = correct.item()

        metrics['total_loss'] = metrics['retriever_loss'] + metrics['reranker_loss']

        return metrics

    def _update_hard_negatives(self):
        """Update hard negatives using current models (simplified version)"""
        # This would implement the dynamic hard negative mining.
        # For now, just log that we would update.
        if is_main_process():
            logger.info(f"Would update hard negatives at step {self.global_step}")

    def save_models(self, output_dir: str):
        """Save both models"""
        if not is_main_process():
            return

        # Save retriever
        retriever_path = os.path.join(output_dir, "retriever")
        os.makedirs(retriever_path, exist_ok=True)

        if hasattr(self.retriever, 'module'):
            self.retriever.module.save_pretrained(retriever_path)  # type: ignore[attr-defined]
        else:
            self.retriever.save_pretrained(retriever_path)  # type: ignore[attr-defined]

        # Save reranker
        reranker_path = os.path.join(output_dir, "reranker")
        os.makedirs(reranker_path, exist_ok=True)

        if hasattr(self.reranker, 'module'):
            self.reranker.module.save_pretrained(reranker_path)  # type: ignore[attr-defined]
        else:
            self.reranker.save_pretrained(reranker_path)  # type: ignore[attr-defined]

        logger.info(f"Saved models to {output_dir}")
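

# ---------------------------------------------------------------------------
# Generic illustration of listwise distillation (a sketch only, not the
# implementation in models.losses): the student's softmax over candidate
# scores is pulled toward the detached teacher distribution via KL divergence.
# The function name and signature below are hypothetical.
def _listwise_distillation_example(student_scores: torch.Tensor,
                                   teacher_scores: torch.Tensor,
                                   temperature: float = 1.0) -> torch.Tensor:
    # student_scores / teacher_scores: [batch, num_candidates]
    student_log_probs = torch.log_softmax(student_scores / temperature, dim=-1)
    teacher_probs = torch.softmax(teacher_scores.detach() / temperature, dim=-1)
    # KL(teacher || student) averaged over the batch
    return nn.functional.kl_div(student_log_probs, teacher_probs, reduction="batchmean")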
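

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# The config attributes referenced are the ones this module actually reads
# (learning_rate, weight_decay, gradient_accumulation_steps, max_grad_norm,
# logging_steps); the constructor calls below are hypothetical, so adapt them
# to the real config-driven entry point.
if __name__ == "__main__":
    config = BaseConfig()                 # hypothetical: real scripts load this from a config file
    retriever = BGEM3Model(config)        # hypothetical constructor signature
    reranker = BGERerankerModel(config)   # hypothetical constructor signature

    trainer = JointTrainer(
        config=config,
        retriever=retriever,
        reranker=reranker,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )

    # A DataLoader for joint training should yield (retriever_batch, reranker_batch)
    # tuples with the keys documented in compute_loss, e.g. built from
    # data.dataset.JointDataset with a suitable collate_fn:
    # metrics = trainer._train_epoch(train_dataloader, epoch=0)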