init commit

ldy
2025-07-22 16:55:25 +08:00
commit 36003b83e2
67 changed files with 76613 additions and 0 deletions

40
training/__init__.py Normal file

@@ -0,0 +1,40 @@
"""
Training modules for BGE models
"""
def get_trainer(trainer_type: str):
"""Lazy loading of trainers to avoid importing all at once"""
if trainer_type == 'base':
from .trainer import BaseTrainer
return BaseTrainer
elif trainer_type == 'm3':
from .m3_trainer import BGEM3Trainer
return BGEM3Trainer
elif trainer_type == 'reranker':
from .reranker_trainer import BGERerankerTrainer
return BGERerankerTrainer
elif trainer_type == 'joint':
from .joint_trainer import JointTrainer
return JointTrainer
else:
raise ValueError(f"Unknown trainer type: {trainer_type}")
# For backward compatibility, import the commonly used ones
from .trainer import BaseTrainer
from .m3_trainer import BGEM3Trainer
# Only import reranker and joint trainers when explicitly requested
def __getattr__(name):
if name == 'BGERerankerTrainer':
from .reranker_trainer import BGERerankerTrainer
return BGERerankerTrainer
elif name == 'JointTrainer':
from .joint_trainer import JointTrainer
return JointTrainer
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
__all__ = [
'BaseTrainer',
'BGEM3Trainer',
'get_trainer'
]
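# Usage sketch (illustrative; `my_config` is a placeholder for a real config
# object, not part of this module):
#
#     from training import get_trainer
#
#     TrainerCls = get_trainer('m3')         # imports m3_trainer only now
#     trainer = TrainerCls(config=my_config)
#
#     # Attribute access on the package is also lazy for the heavier trainers:
#     from training import BGERerankerTrainer  # resolved via __getattr__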

419
training/joint_trainer.py Normal file

@@ -0,0 +1,419 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, List, Tuple
import logging
from tqdm import tqdm
import numpy as np
import os
from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from models.losses import (
InfoNCELoss, ListwiseCrossEntropyLoss,
DynamicListwiseDistillationLoss, MultiTaskLoss
)
from data.dataset import JointDataset
from config.base_config import BaseConfig
from utils.distributed import is_main_process, get_world_size, reduce_dict
from utils.logging import get_logger
# Import autocast and GradScaler, falling back to the public torch.cuda.amp
# entry points on PyTorch versions without the private submodules
try:
    from torch.cuda.amp.autocast_mode import autocast
    from torch.cuda.amp.grad_scaler import GradScaler
except ImportError:
    from torch.cuda.amp import autocast, GradScaler

# Scheduler import (transformers is an optional dependency)
try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None
logger = get_logger(__name__)
class JointTrainer(BaseTrainer):
"""
Joint trainer for BGE-M3 and BGE-Reranker using the RocketQAv2 approach
(dynamic listwise distillation between retriever and reranker)
Notes:
- All parameters (config, retriever, reranker, optimizers, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BaseConfig,
retriever: BGEM3Model,
reranker: BGERerankerModel,
retriever_optimizer: Optional[torch.optim.Optimizer] = None,
reranker_optimizer: Optional[torch.optim.Optimizer] = None,
device: Optional[torch.device] = None
):
# Create default optimizer if none provided
if retriever_optimizer is None:
retriever_optimizer = AdamW(
retriever.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
# We'll use the retriever as the main model for the base class
super().__init__(config, retriever, retriever_optimizer, device=device)
self.retriever = self.model # Alias for clarity
self.reranker = reranker
self.reranker.to(self.device)
# Setup reranker optimizer if not provided
if reranker_optimizer is None:
self.reranker_optimizer = AdamW(
self.reranker.parameters(),
lr=config.learning_rate * 2, # Higher LR for reranker
weight_decay=config.weight_decay
)
else:
self.reranker_optimizer = reranker_optimizer
# Initialize losses
self.retriever_loss_fn = InfoNCELoss(
temperature=getattr(config, 'temperature', 0.02),
use_in_batch_negatives=True # type: ignore[call-arg]
)
self.reranker_loss_fn = ListwiseCrossEntropyLoss(
temperature=1.0,
label_smoothing=0.0 # type: ignore[call-arg]
)
self.distillation_loss_fn = DynamicListwiseDistillationLoss(
temperature=1.0,
alpha=0.5 # type: ignore[call-arg]
)
# Joint training specific parameters
self.update_retriever_interval = getattr(config, 'update_retriever_steps', 100)
self.joint_loss_weight = getattr(config, 'joint_loss_weight', 0.3)
self.use_dynamic_distillation = getattr(config, 'use_dynamic_listwise_distillation', True)
# Tracking metrics
self.retriever_metrics = []
self.reranker_metrics = []
def compute_loss(self, batch: Tuple[Dict, Dict]) -> torch.Tensor:
"""
Compute joint loss for both retriever and reranker
Args:
batch: Tuple of (retriever_batch, reranker_batch)
"""
retriever_batch, reranker_batch = batch
# 1. Compute retriever loss
retriever_loss = self._compute_retriever_loss(retriever_batch)
# 2. Compute reranker loss
reranker_loss = self._compute_reranker_loss(reranker_batch)
# 3. Compute distillation losses if enabled
if self.use_dynamic_distillation:
# Get scores from both models
with torch.no_grad():
retriever_scores = self._get_retriever_scores(retriever_batch)
reranker_scores = self._get_reranker_scores(reranker_batch)
# Compute mutual distillation
retriever_distill_loss, reranker_distill_loss = self.distillation_loss_fn(
retriever_scores,
reranker_scores,
labels=reranker_batch.get('labels')
)
# Add distillation to main losses
retriever_loss = retriever_loss + self.joint_loss_weight * retriever_distill_loss
reranker_loss = reranker_loss + self.joint_loss_weight * reranker_distill_loss
# Combine losses (we'll handle separate backward passes)
self.current_retriever_loss = retriever_loss
self.current_reranker_loss = reranker_loss
# Return combined loss for logging
return retriever_loss + reranker_loss
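# Sketch of the mutual distillation step above (an assumption about what
# DynamicListwiseDistillationLoss computes internally, in the spirit of
# RocketQAv2's dynamic listwise distillation; see models/losses.py for the
# actual implementation):
#
#     p_ret  = softmax(retriever_scores / T)
#     p_rank = softmax(reranker_scores / T)
#     retriever_distill_loss = KL(p_rank.detach() || p_ret)
#     reranker_distill_loss  = KL(p_ret.detach()  || p_rank)
#
# Each model is nudged toward the other's ranking distribution, scaled by
# self.joint_loss_weight before being added to its own task loss.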
def _compute_retriever_loss(self, batch: Dict) -> torch.Tensor:
"""Compute retriever (bi-encoder) loss"""
# Get embeddings
query_outputs = self.retriever(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True
)
passage_outputs = self.retriever(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True
)
# Compute contrastive loss
loss = self.retriever_loss_fn(
query_outputs['dense'],
passage_outputs['dense'],
labels=batch.get('labels')
)
return loss
def _compute_reranker_loss(self, batch: Dict) -> torch.Tensor:
"""Compute reranker (cross-encoder) loss"""
outputs = self.reranker(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
logits = outputs['logits']
labels = batch['labels']
# Reshape for listwise loss if needed
if logits.dim() == 1:
# Binary classification
loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
else:
# Listwise ranking
loss = self.reranker_loss_fn(logits, labels)
return loss
def _get_retriever_scores(self, batch: Dict) -> torch.Tensor:
"""Get retriever similarity scores"""
with torch.no_grad():
query_outputs = self.retriever(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True
)
passage_outputs = self.retriever(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True
)
# Compute similarities
scores = torch.matmul(
query_outputs['dense'],
passage_outputs['dense'].transpose(-1, -2)
)
return scores
def _get_reranker_scores(self, batch: Dict) -> torch.Tensor:
"""Get reranker scores"""
outputs = self.reranker(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
return outputs['logits']
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
"""Train both models for one epoch"""
self.retriever.train()
self.reranker.train()
epoch_metrics = {
'retriever_loss': 0.0,
'reranker_loss': 0.0,
'total_loss': 0.0
}
num_batches = 0
# Set epoch for distributed sampler
if hasattr(dataloader, 'sampler') and hasattr(dataloader.sampler, 'set_epoch'):
dataloader.sampler.set_epoch(epoch) # type: ignore[attr-defined]
progress_bar = tqdm(
dataloader,
desc=f"Joint Training Epoch {epoch + 1}",
disable=not is_main_process()
)
for step, batch in enumerate(progress_bar):
# Prepare batch
retriever_batch, reranker_batch = batch
retriever_batch = self._prepare_batch(retriever_batch)
reranker_batch = self._prepare_batch(reranker_batch)
batch = (retriever_batch, reranker_batch)
# Forward pass
total_loss = self.compute_loss(batch)
# Scale losses for gradient accumulation
retriever_loss = self.current_retriever_loss / self.config.gradient_accumulation_steps
reranker_loss = self.current_reranker_loss / self.config.gradient_accumulation_steps
# Backward passes (separate for each model)
if self.scaler:
# Retriever backward
self.scaler.scale(retriever_loss).backward(retain_graph=True) # type: ignore[attr-defined]
# Reranker backward
self.scaler.scale(reranker_loss).backward() # type: ignore[attr-defined]
else:
retriever_loss.backward(retain_graph=True)
reranker_loss.backward()
# Update weights
if (step + 1) % self.config.gradient_accumulation_steps == 0:
# Gradient clipping
if self.scaler:
self.scaler.unscale_(self.optimizer)
self.scaler.unscale_(self.reranker_optimizer)
torch.nn.utils.clip_grad_norm_( # type: ignore[attr-defined]
self.retriever.parameters(),
self.config.max_grad_norm
)
torch.nn.utils.clip_grad_norm_( # type: ignore[attr-defined]
self.reranker.parameters(),
self.config.max_grad_norm
)
# Optimizer steps
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.step(self.reranker_optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.reranker_optimizer.step()
# Scheduler steps
if self.scheduler:
self.scheduler.step()
# Note: Add reranker scheduler if needed
# Zero gradients
self.optimizer.zero_grad()
self.reranker_optimizer.zero_grad()
self.global_step += 1
# Update hard negatives periodically
if self.global_step % self.update_retriever_interval == 0:
self._update_hard_negatives()
# Logging
if self.global_step % self.config.logging_steps == 0:
step_metrics = {
'retriever_loss': retriever_loss.item() * self.config.gradient_accumulation_steps,
'reranker_loss': reranker_loss.item() * self.config.gradient_accumulation_steps,
'total_loss': (
retriever_loss.item() + reranker_loss.item()) * self.config.gradient_accumulation_steps,
'retriever_lr': self.optimizer.param_groups[0]['lr'],
'reranker_lr': self.reranker_optimizer.param_groups[0]['lr']
}
if self.is_distributed:
step_metrics = reduce_dict(step_metrics)
progress_bar.set_postfix({
'ret_loss': f"{step_metrics['retriever_loss']:.4f}",
'rank_loss': f"{step_metrics['reranker_loss']:.4f}"
})
if is_main_process():
for key, value in step_metrics.items():
self.tb_logger.log_scalar(f"joint/{key}", value, self.global_step)
# Flush logs periodically
if self.global_step % (self.config.logging_steps * 10) == 0:
self.tb_logger.flush()
# Accumulate metrics
epoch_metrics['retriever_loss'] += retriever_loss.item()
epoch_metrics['reranker_loss'] += reranker_loss.item()
epoch_metrics['total_loss'] += (retriever_loss.item() + reranker_loss.item())
num_batches += 1
# Average metrics
for key in epoch_metrics:
epoch_metrics[key] = epoch_metrics[key] / num_batches * self.config.gradient_accumulation_steps
if self.is_distributed:
epoch_metrics = reduce_dict(epoch_metrics)
return epoch_metrics
def evaluate_step(self, batch: Tuple[Dict, Dict]) -> Dict[str, float]:
"""Evaluate both models on a batch"""
retriever_batch, reranker_batch = batch
metrics = {}
# Evaluate retriever
with torch.no_grad():
retriever_loss = self._compute_retriever_loss(retriever_batch)
metrics['retriever_loss'] = retriever_loss.item()
# Compute retrieval metrics
retriever_scores = self._get_retriever_scores(retriever_batch)
labels = retriever_batch['labels']
# Simple accuracy: is the positive passage ranked first?
top_indices = retriever_scores.argmax(dim=-1)
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
metrics['retriever_accuracy'] = correct.item()
# Evaluate reranker
with torch.no_grad():
reranker_loss = self._compute_reranker_loss(reranker_batch)
metrics['reranker_loss'] = reranker_loss.item()
# Compute reranking metrics
reranker_scores = self._get_reranker_scores(reranker_batch)
labels = reranker_batch['labels']
if reranker_scores.dim() == 1:
# Binary accuracy
predictions = (reranker_scores > 0).float()
correct = (predictions == labels).float().mean()
else:
# Ranking accuracy
top_indices = reranker_scores.argmax(dim=-1)
correct = (top_indices == labels.argmax(dim=-1)).float().mean()
metrics['reranker_accuracy'] = correct.item()
metrics['total_loss'] = metrics['retriever_loss'] + metrics['reranker_loss']
return metrics
def _update_hard_negatives(self):
"""Update hard negatives using current models (simplified version)"""
# This would implement the dynamic hard negative mining
# For now, just log that we would update
if is_main_process():
logger.info(f"Would update hard negatives at step {self.global_step}")
def save_models(self, output_dir: str):
"""Save both models"""
if not is_main_process():
return
# Save retriever
retriever_path = os.path.join(output_dir, "retriever")
os.makedirs(retriever_path, exist_ok=True)
if hasattr(self.retriever, 'module'):
self.retriever.module.save_pretrained(retriever_path) # type: ignore[attr-defined]
else:
self.retriever.save_pretrained(retriever_path) # type: ignore[attr-defined]
# Save reranker
reranker_path = os.path.join(output_dir, "reranker")
os.makedirs(reranker_path, exist_ok=True)
if hasattr(self.reranker, 'module'):
self.reranker.module.save_pretrained(reranker_path) # type: ignore[attr-defined]
else:
self.reranker.save_pretrained(reranker_path) # type: ignore[attr-defined]
logger.info(f"Saved models to {output_dir}")

894
training/m3_trainer.py Normal file

@@ -0,0 +1,894 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer
import os
import json
import logging
from tqdm import tqdm
from typing import Dict, Optional, List, Tuple
import numpy as np
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.losses import InfoNCELoss, M3KnowledgeDistillationLoss, MultiTaskLoss
from data.dataset import BGEM3Dataset
# Optional hard negative mining
try:
from data.hard_negative_mining import ANCEHardNegativeMiner
ANCE_AVAILABLE = True
except ImportError:
ANCE_AVAILABLE = False
ANCEHardNegativeMiner = None
from utils.checkpoint import CheckpointManager
from utils.logging import get_logger
from evaluation.metrics import compute_retrieval_metrics
from config.m3_config import BGEM3Config
# Import scheduler with fallback
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = get_logger(__name__)
class BGEM3Trainer(BaseTrainer):
"""
Trainer for the BGE-M3 embedding model: multi-objective training over dense,
sparse, and ColBERT representations, with optional self-distillation and
hard negative mining
Notes:
- All parameters (config, model, tokenizer, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BGEM3Config,
model: Optional[BGEM3Model] = None,
tokenizer: Optional[AutoTokenizer] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: Optional[torch.device] = None
):
# Store config first
self.config = config
# Initialize tokenizer first
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
config.model_name_or_path,
cache_dir=config.cache_dir,
trust_remote_code=True
)
else:
self.tokenizer = tokenizer
# Initialize model
if model is None:
self.model = BGEM3Model(
model_name_or_path=config.model_name_or_path,
use_dense=config.use_dense,
use_sparse=config.use_sparse,
use_colbert=config.use_colbert,
pooling_method=config.sentence_pooling_method,
normalize_embeddings=config.normalize_embeddings,
colbert_dim=config.colbert_dim,
sparse_top_k=config.sparse_top_k,
cache_dir=config.cache_dir
)
else:
self.model = model
# Initialize optimizer if not provided
if optimizer is None:
optimizer = self._create_optimizer()
# Initialize scheduler if not provided and get_linear_schedule_with_warmup is available
if scheduler is None and get_linear_schedule_with_warmup is not None:
# We'll create it later when we know the total training steps
pass
# Initialize parent class
super().__init__(config, self.model, optimizer, scheduler, device)
# Initialize losses
self.infonce_loss = InfoNCELoss(
temperature=config.temperature,
reduction='mean'
)
self.kd_loss = M3KnowledgeDistillationLoss(
temperature=config.self_distill_temperature,
alpha=config.self_distill_alpha,
distill_types=['dense', 'sparse', 'colbert']
)
self.multi_task_loss = MultiTaskLoss(
loss_weights={
'dense': config.dense_weight,
'sparse': config.sparse_weight,
'colbert': config.colbert_weight,
'distillation': config.distillation_weight
}
)
# Initialize hard negative miner
if config.use_hard_negatives and ANCE_AVAILABLE:
self.hard_negative_miner = ANCEHardNegativeMiner(
embedding_dim=self.model.config.hidden_size, # type: ignore[attr-defined]
num_hard_negatives=config.num_hard_negatives,
negative_range=config.ance_negative_range,
margin=config.ance_margin,
index_refresh_interval=config.hard_negative_mining_steps,
use_gpu=self.device.type == "cuda"  # handles "cuda:0" etc., not just "cuda"
)
self.hard_negative_miner.start_async_updates()
else:
if config.use_hard_negatives and not ANCE_AVAILABLE:
logger.warning("Hard negative mining requested but ANCEHardNegativeMiner not available")
self.hard_negative_miner = None
logger.info(f"Initialized BGE-M3 trainer with config: {config.to_dict()}")
def _create_optimizer(self) -> torch.optim.Optimizer:
"""Create optimizer with parameter groups"""
# Separate parameters for different learning rates if needed
param_groups = [
{
'params': [p for n, p in self.model.named_parameters()
if 'colbert_linear' not in n and 'sparse_linear' not in n],
'lr': self.config.learning_rate
}
]
# Different LR for task-specific layers
if self.config.use_colbert and hasattr(self.model, 'colbert_linear'):
param_groups.append({
'params': self.model.colbert_linear.parameters(),
'lr': self.config.learning_rate * 2 # Higher LR for new layers
})
if self.config.use_sparse and hasattr(self.model, 'sparse_linear'):
param_groups.append({
'params': self.model.sparse_linear.parameters(),
'lr': self.config.learning_rate * 2
})
optimizer = AdamW(
param_groups,
eps=self.config.adam_epsilon,
weight_decay=self.config.weight_decay,
betas=(self.config.adam_beta1, self.config.adam_beta2)
)
return optimizer
def create_scheduler(self, num_training_steps: int):
"""Create learning rate scheduler"""
if get_linear_schedule_with_warmup is None:
logger.warning("transformers not available, using constant learning rate")
return None
num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
return self.scheduler
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute M3 loss with multiple objectives"""
# Update hard negative mining if hook is available
if hasattr(self, '_mining_hook'):
self._mining_hook()
# Detect batch format and handle accordingly
if 'pos_input_ids' in batch and 'neg_input_ids' in batch:
# This is triplet format from FastM3Dataset
return self._compute_triplet_loss(batch)
elif 'passage_input_ids' in batch:
# This is standard embedding format
return self._compute_embedding_loss(batch)
else:
raise ValueError(f"Unknown batch format. Keys: {list(batch.keys())}")
def _compute_triplet_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for triplet format data (FastM3Dataset)"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get positive embeddings
pos_outputs = self.model.forward(
input_ids=batch['pos_input_ids'],
attention_mask=batch['pos_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get negative embeddings
neg_outputs = self.model.forward(
input_ids=batch['neg_input_ids'],
attention_mask=batch['neg_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
losses = {}
# Dense loss (InfoNCE with explicit pos/neg)
if 'dense' in query_outputs:
query_dense = query_outputs['dense'].float() # type: ignore[index]
pos_dense = pos_outputs['dense'].float() # type: ignore[index]
neg_dense = neg_outputs['dense'].float() # type: ignore[index]
# For triplet format, reshape negative embeddings to expected 3D format
# InfoNCE expects negative_embeddings: [batch_size, num_negatives, embedding_dim]
# But we have single negatives: [batch_size, embedding_dim]
# Reshape to [batch_size, 1, embedding_dim]
neg_dense_3d = neg_dense.unsqueeze(1) # [batch_size, 1, embedding_dim]
dense_loss = self.infonce_loss(
query_dense,
pos_dense,
neg_dense_3d # Now properly shaped 3D tensor
)
losses['dense'] = dense_loss
# Sparse loss
if self.config.use_sparse and 'sparse' in query_outputs:
from models.losses import SparseLoss
sparse_loss_fn = SparseLoss(
flops_loss_weight=1e-3,
sparse_reg_weight=1e-4
)
# Combine pos and neg for sparse loss
combined_sparse = torch.cat([pos_outputs['sparse'], neg_outputs['sparse']], dim=0) # type: ignore[index]
query_sparse_repeated = query_outputs['sparse'].repeat(2, 1) # type: ignore[index]
labels = torch.cat([torch.ones(len(pos_outputs['sparse'])), # type: ignore[index]
torch.zeros(len(neg_outputs['sparse']))], dim=0).to(query_sparse_repeated.device) # type: ignore[index]
sparse_loss, _ = sparse_loss_fn(
query_sparse_repeated,
combined_sparse,
labels
)
losses['sparse'] = sparse_loss
# ColBERT loss
if self.config.use_colbert and 'colbert' in query_outputs:
from models.losses import ColBERTLoss
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
# Combine pos and neg for ColBERT loss
combined_colbert = torch.cat([pos_outputs['colbert'], neg_outputs['colbert']], dim=0) # type: ignore[index]
query_colbert_repeated = query_outputs['colbert'].repeat(2, 1, 1) # type: ignore[index]
labels = torch.cat([torch.ones(len(pos_outputs['colbert'])), # type: ignore[index]
torch.zeros(len(neg_outputs['colbert']))], dim=0).to(query_colbert_repeated.device) # type: ignore[index]
colbert_loss = colbert_loss_fn(
query_colbert_repeated,
combined_colbert,
labels
)
losses['colbert'] = colbert_loss
# Combine losses
total_loss, _ = self.multi_task_loss(losses)
return total_loss
def _compute_embedding_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for standard embedding format data"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get passage embeddings - handle both single and multiple passages
if batch['passage_input_ids'].dim() == 3:
# Multiple passages per query [batch, num_passages, seq_len]
batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
passage_input_ids = batch['passage_input_ids'].view(batch_size * num_passages, seq_len)
passage_attention_mask = batch['passage_attention_mask'].view(batch_size * num_passages, seq_len)
multiple_passages = True
else:
# Single passage per query [batch, seq_len]
passage_input_ids = batch['passage_input_ids']
passage_attention_mask = batch['passage_attention_mask']
batch_size = passage_input_ids.shape[0]
num_passages = 1
multiple_passages = False
passage_outputs = self.model.forward(
input_ids=passage_input_ids,
attention_mask=passage_attention_mask,
return_dense=True,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Compute losses for different components
losses = {}
# Dense loss (InfoNCE)
if 'dense' in query_outputs and 'dense' in passage_outputs:
# Get embeddings and ensure they're on the same device
query_dense = query_outputs['dense'].float() # type: ignore[index] # [batch_size, embedding_dim]
passage_dense = passage_outputs['dense'].float() # type: ignore[index] # [batch_size * num_passages, embedding_dim]
# Handle multiple passages per query differently for contrastive learning
if multiple_passages:
# Reshape passage embeddings back to [batch_size, num_passages, embedding_dim]
passage_dense = passage_dense.view(batch_size, num_passages, -1)
# For contrastive learning, we need to identify positive passages
# Use the labels to find the first positive passage for each query
labels = batch.get('labels') # [batch_size, num_passages]
if labels is not None:
# Find the first positive passage for each query
positive_indices = []
negative_passages = []
for i in range(batch_size):
query_labels = labels[i] # [num_passages]
positive_mask = query_labels > 0
if positive_mask.any():
# Get first positive passage
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
positive_indices.append(pos_idx)
# Collect negative passages
neg_mask = ~positive_mask
if neg_mask.any():
neg_embeddings = passage_dense[i][neg_mask] # [num_negatives, embedding_dim]
negative_passages.append(neg_embeddings)
else:
negative_passages.append(None)
else:
# No positive passage, use first passage as positive (fallback)
positive_indices.append(0)
negative_passages.append(passage_dense[i][1:])
# Extract positive embeddings
positive_embeddings = torch.stack([
passage_dense[i, positive_indices[i]] for i in range(batch_size)
]) # [batch_size, embedding_dim]
# Stack negative embeddings (if available)
max_neg = max(neg.shape[0] if neg is not None else 0 for neg in negative_passages)
if max_neg > 0:
# Pad negative embeddings to same size
padded_negatives = []
for neg in negative_passages:
if neg is not None and neg.shape[0] > 0:
if neg.shape[0] < max_neg:
# Pad with zeros
padding = torch.zeros(max_neg - neg.shape[0], neg.shape[1], device=neg.device)
neg = torch.cat([neg, padding], dim=0)
padded_negatives.append(neg)
else:
# No negatives, create zero tensor
padded_negatives.append(torch.zeros(max_neg, passage_dense.shape[-1], device=passage_dense.device))
negative_embeddings = torch.stack(padded_negatives) # [batch_size, max_neg, embedding_dim]
else:
negative_embeddings = None
# Compute InfoNCE loss with explicit positives and negatives
dense_loss = self.infonce_loss(
query_dense,
positive_embeddings,
negative_embeddings
)
else:
# No labels provided, use first passage as positive, rest as negatives
positive_embeddings = passage_dense[:, 0, :] # [batch_size, embedding_dim]
if num_passages > 1:
negative_embeddings = passage_dense[:, 1:, :] # [batch_size, num_passages-1, embedding_dim]
else:
negative_embeddings = None
dense_loss = self.infonce_loss(
query_dense,
positive_embeddings,
negative_embeddings
)
else:
# Single passage per query - use standard in-batch negatives
# InfoNCE will create contrastive pairs using in-batch negatives
dense_loss = self.infonce_loss(
query_dense,
passage_dense,
None # No explicit negatives, will use in-batch
)
losses['dense'] = dense_loss
# Sparse loss (if using sparse)
if self.config.use_sparse and 'sparse' in query_outputs and 'sparse' in passage_outputs:
from models.losses import SparseLoss
sparse_loss_fn = SparseLoss(
flops_loss_weight=1e-3,
sparse_reg_weight=1e-4
)
query_sparse = query_outputs['sparse'] # type: ignore[index] # [batch_size, query_seq_len]
passage_sparse = passage_outputs['sparse'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len]
# For sparse loss, we need to handle the contrastive learning differently
# Since sparse embeddings have different sequence lengths, we'll use the same approach as dense
if multiple_passages:
# Use only the first positive passage for each query
labels = batch.get('labels')
if labels is not None:
positive_indices = []
for i in range(batch_size):
query_labels = labels[i]
positive_mask = query_labels > 0
if positive_mask.any():
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
else:
pos_idx = 0
positive_indices.append(pos_idx + i * num_passages)
# Extract positive sparse embeddings
positive_sparse = passage_sparse[positive_indices] # [batch_size, doc_seq_len]
else:
# Use first passage as positive
positive_sparse = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
# Create contrastive labels for sparse loss
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
sparse_loss, _ = sparse_loss_fn(
query_sparse,
positive_sparse,
contrastive_labels
)
else:
# Single passage per query
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
sparse_loss, _ = sparse_loss_fn(
query_sparse,
passage_sparse,
contrastive_labels
)
losses['sparse'] = sparse_loss
# ColBERT loss (if using ColBERT)
if self.config.use_colbert and 'colbert' in query_outputs and 'colbert' in passage_outputs:
from models.losses import ColBERTLoss
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
query_colbert = query_outputs['colbert'] # type: ignore[index] # [batch_size, query_seq_len, dim]
passage_colbert = passage_outputs['colbert'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len, dim]
# For ColBERT, we need pairwise interactions
if multiple_passages:
# Use only positive passages
labels = batch.get('labels')
if labels is not None:
positive_indices = []
for i in range(batch_size):
query_labels = labels[i]
positive_mask = query_labels > 0
if positive_mask.any():
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
else:
pos_idx = 0
positive_indices.append(pos_idx + i * num_passages)
positive_colbert = passage_colbert[positive_indices]
else:
positive_colbert = passage_colbert[::num_passages]
# Binary labels for ColBERT (1 for positive pairs)
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
else:
positive_colbert = passage_colbert
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
colbert_loss = colbert_loss_fn(
query_colbert,
positive_colbert,
colbert_labels
)
losses['colbert'] = colbert_loss
# Knowledge distillation loss (only if explicitly enabled)
if (getattr(self.config, 'use_self_distill', False) and
self.global_step >= getattr(self.config, 'self_distill_start_step', 0)):
try:
# Compute teacher scores (weighted combination)
teacher_scores = self._compute_teacher_scores(query_outputs, passage_outputs) # type: ignore[arg-type]
student_scores = self.model.compute_similarity(
query_outputs, passage_outputs, mode='dense'
)
kd_loss, _ = self.kd_loss(
{'dense': student_scores},
{'dense': teacher_scores}
)
losses['kd'] = kd_loss
except Exception as e:
logger.warning(f"Knowledge distillation failed: {e}")
# Combine losses
total_loss, _ = self.multi_task_loss(losses)
return total_loss
def _compute_teacher_scores(
self,
query_outputs: Dict[str, torch.Tensor],
passage_outputs: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""Compute teacher scores from multiple signals"""
scores = []
weights = []
# Get original batch size for reshaping
query_batch_size = query_outputs['dense'].shape[0] if 'dense' in query_outputs else 1
passage_batch_size = passage_outputs['dense'].shape[0] if 'dense' in passage_outputs else 1
# Check if we have multiple passages per query
multiple_passages = passage_batch_size > query_batch_size
if multiple_passages:
# For multiple passages, we need to compute scores only for positive pairs
# Use only the first passage for each query as a simplified teacher score
num_passages = passage_batch_size // query_batch_size
# Dense scores (use only first passage per query)
if 'dense' in query_outputs and 'dense' in passage_outputs:
query_dense = query_outputs['dense'] # [batch_size, dim]
passage_dense = passage_outputs['dense'] # [batch_size * num_passages, dim]
# Extract first passage for each query
positive_passages = passage_dense[::num_passages] # [batch_size, dim]
try:
dense_scores = self.model.compute_similarity(
{'dense': query_dense},
{'dense': positive_passages},
mode='dense'
)
scores.append(dense_scores)
weights.append(self.config.dense_weight)
except Exception as e:
logger.warning(f"Dense similarity computation failed: {e}")
# Sparse scores (use only first passage per query)
if 'sparse' in query_outputs and 'sparse' in passage_outputs:
query_sparse = query_outputs['sparse'] # [batch_size, query_seq_len]
passage_sparse = passage_outputs['sparse'] # [batch_size * num_passages, doc_seq_len]
# Extract first passage for each query
positive_passages = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
try:
sparse_scores = self.model.compute_similarity(
{'sparse': query_sparse},
{'sparse': positive_passages},
mode='sparse'
)
scores.append(sparse_scores)
weights.append(self.config.sparse_weight)
except Exception as e:
logger.warning(f"Sparse similarity computation failed: {e}")
# ColBERT scores (use only first passage per query)
if 'colbert' in query_outputs and 'colbert' in passage_outputs:
query_colbert = query_outputs['colbert'] # [batch_size, query_seq_len, dim]
passage_colbert = passage_outputs['colbert'] # [batch_size * num_passages, doc_seq_len, dim]
# Extract first passage for each query
positive_passages = passage_colbert[::num_passages] # [batch_size, doc_seq_len, dim]
try:
colbert_scores = self.model.compute_similarity(
{'colbert': query_colbert},
{'colbert': positive_passages},
mode='colbert'
)
scores.append(colbert_scores)
weights.append(self.config.colbert_weight)
except Exception as e:
logger.warning(f"ColBERT similarity computation failed: {e}")
else:
# Single passage per query - original logic
# Dense scores
if 'dense' in query_outputs:
try:
dense_scores = self.model.compute_similarity(
{'dense': query_outputs['dense']},
{'dense': passage_outputs['dense']},
mode='dense'
)
scores.append(dense_scores)
weights.append(self.config.dense_weight)
except Exception as e:
logger.warning(f"Dense similarity computation failed: {e}")
# Sparse scores
if 'sparse' in query_outputs:
try:
sparse_scores = self.model.compute_similarity(
{'sparse': query_outputs['sparse']},
{'sparse': passage_outputs['sparse']},
mode='sparse'
)
scores.append(sparse_scores)
weights.append(self.config.sparse_weight)
except Exception as e:
logger.warning(f"Sparse similarity computation failed: {e}")
# ColBERT scores
if 'colbert' in query_outputs:
try:
colbert_scores = self.model.compute_similarity(
{'colbert': query_outputs['colbert']},
{'colbert': passage_outputs['colbert']},
mode='colbert'
)
scores.append(colbert_scores)
weights.append(self.config.colbert_weight)
except Exception as e:
logger.warning(f"ColBERT similarity computation failed: {e}")
# Weighted combination
if scores:
weights = torch.tensor(weights, device=scores[0].device)
weights = weights / weights.sum()
teacher_scores = torch.stack([w * s for w, s in zip(weights, scores)]).sum(dim=0)
return teacher_scores
# Fallback to zero scores if no valid similarity computation
fallback_size = query_batch_size  # one teacher score per query in either case
return torch.zeros(fallback_size, device=query_outputs[list(query_outputs.keys())[0]].device)
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Evaluate a single batch"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True,
return_dict=True
)
# Get passage embeddings
passage_outputs = self.model.forward(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True,
return_dict=True
)
# Compute loss
loss = self.compute_loss(batch)
# Compute similarity scores
scores = self.model.compute_similarity(
query_outputs, passage_outputs, mode='dense'
)
# Basic metrics
labels = batch.get('labels', torch.zeros(scores.shape[0]))
if labels.dim() > 1:
labels = labels.view(-1)
predictions = (scores > 0.5).float()
accuracy = (predictions == labels.float()).float().mean()
return {
'loss': loss.item(),
'accuracy': accuracy.item(),
'mean_score': scores.mean().item()
}
def train(
self,
train_dataset: BGEM3Dataset,
eval_dataset: Optional[BGEM3Dataset] = None,
resume_from_checkpoint: Optional[str] = None
):
"""
Main training loop for BGEM3Trainer with robust error handling, config-driven parameters, and progress tracking.
"""
try:
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
# Calculate total training steps
num_update_steps_per_epoch = max(1, len(train_dataloader) // self.config.gradient_accumulation_steps) # type: ignore[arg-type]
total_training_steps = num_update_steps_per_epoch * self.config.num_train_epochs
# Create scheduler with total steps
self.create_scheduler(total_training_steps)
if resume_from_checkpoint:
self.load_checkpoint(resume_from_checkpoint)
logger.info(f"Starting training for {self.config.num_train_epochs} epochs")
logger.info(f"Total training steps: {total_training_steps}")
# Use parent's train method
super().train(train_dataloader, eval_dataloader, self.config.num_train_epochs) # type: ignore[arg-type]
except Exception as e:
logger.error(f"M3 training failed: {e}")
raise
finally:
# Cleanup
if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
self.hard_negative_miner.stop_async_updates()
def _create_dataloader(
self,
dataset: Optional[BGEM3Dataset],
is_train: bool = True
) -> Optional[DataLoader]:
"""Create data loader"""
if dataset is None:
return None
# Import both collate functions
from data.dataset import collate_embedding_batch
# Choose appropriate collate function based on dataset type
collate_fn = collate_embedding_batch
num_workers = self.config.dataloader_num_workers
batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size
# Setup distributed sampler if needed
sampler = None
if self.is_distributed:
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, shuffle=is_train)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=is_train and sampler is None,
sampler=sampler,
num_workers=num_workers,
pin_memory=self.config.dataloader_pin_memory,
drop_last=is_train and len(dataset) > batch_size, # Only drop last if we have multiple batches
collate_fn=collate_fn
)
return dataloader
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
"""Train one epoch - use parent's method with hard negative mining integration"""
# Add hard negative mining updates before calling parent's method
if self.hard_negative_miner:
# Set up a hook to update hard negatives during training
def mining_hook():
if self.hard_negative_miner:
self.hard_negative_miner.update_index_if_needed(self.global_step, self.model)
# Store the hook for use in compute_loss
self._mining_hook = mining_hook
# Use parent's robust training loop
return super().train_epoch(dataloader, epoch)
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
"""Evaluate the model"""
self.model.eval()
all_query_embeddings = []
all_passage_embeddings = []
all_labels = []
with torch.no_grad():
with tqdm(dataloader, desc="Evaluating") as tbar:
for batch in tbar:
try:
batch = self._prepare_batch(batch)
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True,
return_dict=True
)
passage_outputs = self.model.forward(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True,
return_dict=True
)
all_query_embeddings.append(query_outputs['dense'].cpu()) # type: ignore[index]
all_passage_embeddings.append(passage_outputs['dense'].cpu()) # type: ignore[index]
all_labels.append(batch['labels'].cpu())
except Exception as e:
logger.error(f"Error in M3 evaluation step: {e}")
continue
if not all_query_embeddings:
logger.warning("No valid evaluation batches processed")
return {'loss': float('inf'), 'accuracy': 0.0}
# Concatenate all embeddings
all_query_embeddings = torch.cat(all_query_embeddings, dim=0)
all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0)
all_labels = torch.cat(all_labels, dim=0)
# Compute metrics
try:
metrics = compute_retrieval_metrics(
all_query_embeddings.numpy(),
all_passage_embeddings.numpy(),
all_labels.numpy(),
k_values=[1, 3, 5, 10]
)
except Exception as e:
logger.error(f"Error computing retrieval metrics: {e}")
metrics = {'loss': float('inf'), 'accuracy': 0.0}
return metrics
def save_model(self, save_path: str):
"""Save the trained model"""
try:
os.makedirs(save_path, exist_ok=True)
# Save model
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
model_to_save.save_pretrained(save_path) # type: ignore[attr-defined]
# Save tokenizer
self.tokenizer.save_pretrained(save_path) # type: ignore[attr-defined]
# Save config
self.config.save(os.path.join(save_path, 'training_config.json'))
logger.info(f"Model saved to {save_path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
raise
def __del__(self):
"""Cleanup when trainer is destroyed"""
if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
self.hard_negative_miner.stop_async_updates()
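# Usage sketch (the config fields shown are the ones this trainer reads;
# dataset construction is an assumption, see data/dataset.py for the real API):
#
#     config = BGEM3Config(model_name_or_path="BAAI/bge-m3",
#                          use_dense=True, use_sparse=True, use_colbert=True)
#     trainer = BGEM3Trainer(config)   # builds tokenizer/model/optimizer
#     trainer.train(train_dataset, eval_dataset=dev_dataset)
#     trainer.save_model(config.output_dir)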

489
training/reranker_trainer.py Normal file

@@ -0,0 +1,489 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from typing import Dict, Optional, List, Tuple, Any
import logging
import numpy as np
from torch.nn import functional as F
from tqdm import tqdm
import os
from .trainer import BaseTrainer
from models.bge_reranker import BGERerankerModel
from models.losses import ListwiseCrossEntropyLoss, PairwiseMarginLoss
from data.dataset import BGERerankerDataset
from config.reranker_config import BGERerankerConfig
from evaluation.metrics import compute_reranker_metrics
from utils.distributed import is_main_process
# AMP imports
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
# Scheduler import
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = logging.getLogger(__name__)
class BGERerankerTrainer(BaseTrainer):
"""
Trainer for the BGE-Reranker cross-encoder model
Notes:
- All parameters (config, model, tokenizer, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BGERerankerConfig,
model: Optional[BGERerankerModel] = None,
tokenizer=None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: Optional[torch.device] = None
):
# Initialize model if not provided
if model is None:
model = BGERerankerModel(
model_name_or_path=config.model_name_or_path,
num_labels=config.num_labels,
cache_dir=config.cache_dir,
use_fp16=config.fp16,
gradient_checkpointing=config.gradient_checkpointing
)
# Initialize optimizer if not provided
if optimizer is None:
optimizer = self._create_optimizer(model, config)
# Initialize parent class
super().__init__(config, model, optimizer, scheduler, device)
self.tokenizer = tokenizer
self.config = config
# Initialize losses based on config
if config.loss_type == "listwise_ce":
self.loss_fn = ListwiseCrossEntropyLoss(
temperature=getattr(config, 'temperature', 1.0)
)
elif config.loss_type == "pairwise_margin":
self.loss_fn = PairwiseMarginLoss(margin=config.margin)
elif config.loss_type == "combined":
self.listwise_loss = ListwiseCrossEntropyLoss(
temperature=getattr(config, 'temperature', 1.0)
)
self.pairwise_loss = PairwiseMarginLoss(margin=config.margin)
else:
self.loss_fn = nn.BCEWithLogitsLoss()
# Knowledge distillation setup
self.use_kd = config.use_knowledge_distillation
if self.use_kd:
self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
self.kd_temperature = config.distillation_temperature
self.kd_loss_weight = config.distillation_alpha
# Hard negative configuration
self.hard_negative_ratio = config.hard_negative_ratio
self.use_denoising = config.use_denoising
self.denoise_threshold = getattr(config, 'denoise_confidence_threshold', 0.5)
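# loss_type values handled here and in compute_loss below:
#   "listwise_ce"     -> ListwiseCrossEntropyLoss over train_group_size groups
#   "pairwise_margin" -> PairwiseMarginLoss with config.margin
#   "combined"        -> listwise_weight * listwise + pairwise_weight * pairwise
#   anything else     -> plain binary cross-entropy with logits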
def _create_optimizer(self, model: nn.Module, config: BGERerankerConfig) -> torch.optim.Optimizer:
"""Create optimizer with proper parameter groups"""
# Separate parameters
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{
'params': [p for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)],
'weight_decay': config.weight_decay,
},
{
'params': [p for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)],
'weight_decay': 0.0,
}
]
optimizer = AdamW(
optimizer_grouped_parameters,
lr=config.learning_rate,
eps=config.adam_epsilon,
betas=(config.adam_beta1, config.adam_beta2)
)
return optimizer
def create_scheduler(self, num_training_steps: int) -> Optional[object]:
"""Create learning rate scheduler"""
if get_linear_schedule_with_warmup is None:
logger.warning("transformers not available, using constant learning rate")
return None
num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
return self.scheduler
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute reranker loss"""
# Get model outputs
outputs = self.model.forward(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
return_dict=True
)
if isinstance(outputs, dict):
logits = outputs['logits']
else:
logits = outputs.scores
labels = batch['labels']
# Handle padding labels
if labels.dim() > 1:
# Remove padding (-1 labels)
valid_mask = labels != -1
if valid_mask.any():
logits = logits[valid_mask]
labels = labels[valid_mask]
else:
# If all labels are padding, return zero loss
return torch.tensor(0.0, device=logits.device, requires_grad=True)
# Handle different loss types
if self.config.loss_type == "listwise_ce":
# Listwise loss expects grouped format
try:
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
if groups:
loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
else:
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
except Exception as e:
# Fallback to BCE with specific error logging
logger.warning(f"Listwise CE loss failed: {e}. Falling back to BCE.")
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
elif self.config.loss_type == "pairwise_margin":
# Extract positive and negative scores
pos_mask = labels == 1
neg_mask = labels == 0
if pos_mask.any() and neg_mask.any():
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
try:
loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups)
except Exception as e:
# Fallback to BCE with specific error logging
logger.warning(f"Pairwise margin loss failed: {e}. Falling back to BCE.")
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
else:
# Fallback to BCE if no proper pairs
logger.warning("No positive-negative pairs found for pairwise loss. Using BCE.")
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
elif self.config.loss_type == "combined":
# Combine listwise and pairwise losses
try:
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
if groups:
listwise_loss = self.listwise_loss(logits.squeeze(-1), labels.float(), groups)
pairwise_loss = self.pairwise_loss(logits.squeeze(-1), labels.float(), groups)
loss = self.config.listwise_weight * listwise_loss + self.config.pairwise_weight * pairwise_loss
else:
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
except Exception as e:
# Fallback to BCE with specific error logging
logger.warning(f"Combined loss computation failed: {e}. Falling back to BCE.")
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
else:
# Default binary cross-entropy
loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
# Add knowledge distillation loss if enabled
if self.use_kd and 'teacher_scores' in batch:
teacher_scores = batch['teacher_scores']
student_probs = F.log_softmax(logits / self.kd_temperature, dim=-1)
teacher_probs = F.softmax(teacher_scores / self.kd_temperature, dim=-1)
kd_loss = self.kd_loss_fn(student_probs, teacher_probs) * (self.kd_temperature ** 2)
loss = (1 - self.kd_loss_weight) * loss + self.kd_loss_weight * kd_loss
# Apply denoising if enabled
if self.use_denoising and 'teacher_scores' in batch:
# Denoise by confidence thresholding
confidence_scores = torch.sigmoid(batch['teacher_scores'])
confidence_mask = confidence_scores > self.denoise_threshold
if confidence_mask.any():
loss = loss * confidence_mask.float().mean()
return loss
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Evaluate a single batch"""
with torch.no_grad():
# Get model outputs
outputs = self.model.forward(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
return_dict=True
)
if isinstance(outputs, dict):
logits = outputs['logits']
else:
logits = outputs.scores
labels = batch['labels']
# Handle padding labels
if labels.dim() > 1:
valid_mask = labels != -1
if valid_mask.any():
logits = logits[valid_mask]
labels = labels[valid_mask]
else:
return {'loss': 0.0, 'accuracy': 0.0}
# Compute loss
loss = self.compute_loss(batch)
# Compute metrics based on group size
if hasattr(self.config, 'train_group_size') and self.config.train_group_size > 1:
# Listwise evaluation
groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
if groups:
try:
metrics = compute_reranker_metrics(
logits.squeeze(-1).cpu().numpy(),
labels.cpu().numpy(),
groups=groups,
k_values=[1, 3, 5, 10]
)
metrics['loss'] = loss.item()
except Exception as e:
logger.warning(f"Error computing reranker metrics: {e}")
metrics = {'loss': loss.item(), 'accuracy': 0.0}
else:
metrics = {'loss': loss.item(), 'accuracy': 0.0}
else:
# Binary evaluation
predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).float()
accuracy = (predictions == labels.float()).float().mean()
metrics = {
'accuracy': accuracy.item(),
'loss': loss.item()
}
# Add AUC if we have both positive and negative examples
if len(torch.unique(labels)) > 1:
try:
from sklearn.metrics import roc_auc_score
auc = roc_auc_score(labels.cpu().numpy(), torch.sigmoid(logits.squeeze(-1)).cpu().numpy())
metrics['auc'] = float(auc)
except Exception as e:
logger.warning(f"Error computing AUC: {e}")
return metrics
def train_with_hard_negative_mining(
self,
train_dataset: BGERerankerDataset,
eval_dataset: Optional[BGERerankerDataset] = None,
retriever_model: Optional[nn.Module] = None
):
"""
Train with periodic hard negative mining from retriever
"""
if retriever_model is None:
logger.warning("No retriever model provided for hard negative mining. Using standard training.")
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
# Only proceed if we have a valid training dataloader
if train_dataloader is not None:
return super().train(train_dataloader, eval_dataloader)
else:
logger.error("No valid training dataloader created")
return
# Training loop with hard negative updates
for epoch in range(self.config.num_train_epochs):
# Mine hard negatives at the beginning of each epoch
if epoch > 0:
logger.info(f"Mining hard negatives for epoch {epoch + 1}")
self._mine_hard_negatives(train_dataset, retriever_model)
# Create data loader with new negatives
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
# Only proceed if we have a valid training dataloader
if train_dataloader is not None:
# Train one epoch
super().train(train_dataloader, eval_dataloader, num_epochs=1)
else:
logger.error(f"No valid training dataloader for epoch {epoch + 1}")
continue
# Update current epoch
self.current_epoch = epoch
def _create_dataloader(self, dataset, is_train: bool = True) -> Optional[DataLoader]:
"""Create data loader"""
if dataset is None:
return None
from data.dataset import collate_reranker_batch
batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size
# Setup distributed sampler if needed
sampler = None
if self.is_distributed:
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, shuffle=is_train)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=is_train and sampler is None,
sampler=sampler,
num_workers=self.config.dataloader_num_workers,
pin_memory=self.config.dataloader_pin_memory,
drop_last=is_train,
collate_fn=collate_reranker_batch
)
def _mine_hard_negatives(
self,
dataset: BGERerankerDataset,
retriever_model: Any
):
"""Mine hard negatives using the retriever model and update the dataset's negatives."""
logger.info("Starting hard negative mining using retriever model...")
num_hard_negatives = getattr(self.config, 'num_hard_negatives', 10)
retriever_batch_size = getattr(self.config, 'retriever_batch_size', 32)
mined_count = 0
# Collect all unique candidate passages from the dataset
all_passages = set()
for item in dataset.data:
all_passages.update(item.get('pos', []))
all_passages.update(item.get('neg', []))
all_passages = list(all_passages)
if not all_passages:
logger.warning("No candidate passages found in dataset for hard negative mining.")
return
# Encode all candidate passages
try:
passage_embeddings = retriever_model.encode(
all_passages,
batch_size=retriever_batch_size,
convert_to_numpy=True
)
except Exception as e:
logger.error(f"Error encoding passages for hard negative mining: {e}")
return
# For each query, mine hard negatives
for idx in range(len(dataset)):
item = dataset.data[idx]
query = item['query']
positives = set(item.get('pos', []))
existing_negs = set(item.get('neg', []))
try:
query_emb = retriever_model.encode([query], batch_size=1, convert_to_numpy=True)[0]
# Compute similarity
scores = np.dot(passage_embeddings, query_emb)
# Rank passages by similarity
ranked_indices = np.argsort(scores)[::-1]
new_negs = []
for i in ranked_indices:
passage = all_passages[int(i)] # Convert to int for indexing
if passage not in positives and passage not in existing_negs:
new_negs.append(passage)
if len(new_negs) >= num_hard_negatives:
break
if new_negs:
item['neg'] = list(existing_negs) + new_negs
mined_count += 1
logger.debug(f"Query idx {idx}: Added {len(new_negs)} hard negatives.")
except Exception as e:
logger.error(f"Error mining hard negatives for query idx {idx}: {e}")
logger.info(f"Hard negative mining complete. Updated negatives for {mined_count} queries.")
def save_model(self, save_path: str):
"""Save the trained model"""
try:
os.makedirs(save_path, exist_ok=True)
# Save model
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
if hasattr(model_to_save, 'save_pretrained'):
model_to_save.save_pretrained(save_path) # type: ignore
else:
# Fallback to torch.save
torch.save(model_to_save.state_dict(), os.path.join(save_path, 'pytorch_model.bin')) # type: ignore
# Save tokenizer if available
if self.tokenizer and hasattr(self.tokenizer, 'save_pretrained'):
self.tokenizer.save_pretrained(save_path)
# Save config
if hasattr(self.config, 'save'):
self.config.save(os.path.join(save_path, 'training_config.json'))
logger.info(f"Model saved to {save_path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
raise
def export_for_inference(self, output_path: str):
"""Export model for efficient inference"""
if is_main_process():
logger.info(f"Exporting model for inference to {output_path}")
# Save model in inference mode
self.model.eval()
# Save model state
torch.save({
'model_state_dict': self.model.state_dict(),
'config': self.config.to_dict(),
'model_type': 'bge_reranker'
}, output_path)
# Also save in HuggingFace format if possible
try:
hf_path = output_path.replace('.pt', '_hf')
self.save_model(hf_path)
except Exception as e:
logger.warning(f"Failed to save in HuggingFace format: {e}")
logger.info(f"Model exported successfully")

534
training/trainer.py Normal file
View File

@@ -0,0 +1,534 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing import Dict, Optional, List, Tuple, Any, Union, cast
import os
import json
import logging
from tqdm import tqdm
from abc import ABC, abstractmethod
import time
from datetime import datetime
import numpy as np
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
from torch.nn.utils.clip_grad import clip_grad_norm_
from collections import defaultdict
from utils.checkpoint import CheckpointManager
from utils.logging import MetricsLogger, TensorBoardLogger, ProgressLogger
from utils.distributed import (
setup_distributed, cleanup_distributed, is_main_process,
reduce_dict, synchronize, get_rank, get_world_size
)
from config.base_config import BaseConfig
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = logging.getLogger(__name__)
class BaseTrainer(ABC):
"""
Abstract base trainer class for BGE models
Notes:
- All parameters (config, model, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BaseConfig,
model: nn.Module,
optimizer: Optimizer,
scheduler: Optional[_LRScheduler] = None,
device: Optional[torch.device] = None
):
self.config = config
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
# Setup device
if device is None:
self.device = torch.device(config.device)
else:
self.device = device
# Move model to device
self.model.to(self.device)
# Setup distributed training if needed
self.is_distributed = config.local_rank != -1
if self.is_distributed:
setup_distributed()
self.model = self._setup_distributed_model()
# Setup mixed precision training
self.use_amp = config.fp16 or config.bf16
if self.use_amp and config.bf16:
# Use bfloat16 for better numerical stability on RTX5090/A100
self.amp_dtype = torch.bfloat16
else:
self.amp_dtype = torch.float16
self.scaler = GradScaler() if self.use_amp and not config.bf16 else None
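        # GradScaler is only required for float16: bfloat16 shares float32's
        # exponent range, so loss scaling is unnecessary and self.scaler stays
        # None on the bf16 path.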
# Memory optimization settings
self.empty_cache_steps = getattr(config, 'empty_cache_steps', 100)
self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)
# Setup logging
self.metrics_logger = MetricsLogger(config.output_dir) if is_main_process() else None
self.tb_logger = TensorBoardLogger(
os.path.join(config.logging_dir, 'tensorboard'),
enabled=is_main_process()
)
self.progress_logger = None # Will be initialized in train()
# Setup checkpointing
self.checkpoint_manager = CheckpointManager(
output_dir=config.output_dir,
save_total_limit=config.save_total_limit
) if is_main_process() else None
# Training state
self.global_step = 0
self.current_epoch = 0
self.best_metric = float('-inf') if config.greater_is_better else float('inf')
self.training_history = []
def _setup_distributed_model(self) -> nn.Module:
"""Setup model for distributed training"""
from torch.nn.parallel import DistributedDataParallel as DDP
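        # find_unused_parameters=True tolerates parameters that receive no
        # gradient in a given step (e.g. partially frozen submodules) at the
        # cost of an extra graph traversal on every backward pass.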
model = DDP(
self.model,
device_ids=[self.config.local_rank],
output_device=self.config.local_rank,
find_unused_parameters=True
)
return model
@abstractmethod
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for a batch - to be implemented by subclasses"""
pass
@abstractmethod
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Evaluate a single batch - to be implemented by subclasses"""
pass
def train(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
num_epochs: Optional[int] = None
    ) -> List[Dict[str, Any]]:
"""
Main training loop with robust error handling, config-driven parameters, and progress tracking.
"""
try:
num_epochs = num_epochs or self.config.num_train_epochs
total_steps = len(train_dataloader) * num_epochs
self.progress_logger = ProgressLogger(
total_steps=total_steps,
log_interval=self.config.logging_steps
)
if is_main_process():
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}") # type: ignore[attr-defined]
logger.info(f" Num epochs = {num_epochs}")
logger.info(f" Batch size = {self.config.per_device_train_batch_size}")
logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {total_steps}")
logger.info(f" Device = {self.device}")
if self.is_distributed:
logger.info(f" Distributed training with {get_world_size()} GPUs")
for epoch in range(self.current_epoch, num_epochs):
self.current_epoch = epoch
epoch_start_time = time.time()
train_metrics = self.train_epoch(train_dataloader, epoch)
# Add metrics to training history
epoch_metrics = {
'epoch': epoch + 1,
'train_metrics': train_metrics,
'global_step': self.global_step,
'learning_rate': self.optimizer.param_groups[0]['lr'] if self.optimizer else 0,
}
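                # NOTE: eval_steps and save_steps are interpreted here as
                # *epoch* intervals; step-based checkpointing happens
                # separately inside train_epoch.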
if eval_dataloader is not None and (epoch + 1) % self.config.eval_steps == 0:
eval_metrics = self.evaluate(eval_dataloader)
epoch_metrics['eval_metrics'] = eval_metrics
current_metric = eval_metrics.get(self.config.metric_for_best_model, 0)
is_best = self._is_better_metric(current_metric)
if is_best:
self.best_metric = current_metric
if is_main_process() and self.checkpoint_manager:
self._save_checkpoint(is_best=is_best, metrics=eval_metrics)
if is_main_process():
logger.info(f"Epoch {epoch + 1} evaluation: {eval_metrics}")
# Add to training history
self.training_history.append(epoch_metrics)
if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
self._save_checkpoint(metrics=train_metrics)
epoch_time = time.time() - epoch_start_time
if is_main_process():
logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds")
logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")
if is_main_process():
self._save_checkpoint(metrics=train_metrics)
self._save_training_summary()
# Close TensorBoard logger
self.tb_logger.close()
if self.is_distributed:
cleanup_distributed()
            return self.training_history
except Exception as e:
logger.error(f"Training failed: {e}")
raise
def train_epoch(
self,
dataloader: DataLoader,
epoch: int
) -> Dict[str, float]:
"""
Train for one epoch with comprehensive error handling and recovery
Args:
dataloader: Training data loader
epoch: Current epoch number
Returns:
Dictionary of training metrics for the epoch
Raises:
RuntimeError: If training fails and cannot recover
"""
self.model.train()
epoch_loss = 0.0
epoch_metrics = defaultdict(float)
num_batches = len(dataloader)
# Track failed batches for recovery
failed_batches = []
max_failures = 3
# Training loop with error recovery
progress_bar = tqdm(
dataloader,
desc=f"Epoch {epoch + 1}",
disable=not is_main_process()
)
for step, batch in enumerate(progress_bar):
try:
# Move batch to device with error handling
try:
batch = self._prepare_batch(batch)
except Exception as e:
logger.warning(f"Failed to prepare batch {step}: {e}")
failed_batches.append((step, 'prepare', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many batch preparation failures: {len(failed_batches)}")
continue
# Compute loss with error recovery
loss: torch.Tensor
try:
if self.use_amp:
with autocast(enabled=True, dtype=self.amp_dtype):
loss = self.compute_loss(batch)
loss = cast(torch.Tensor, loss)
else:
loss = self.compute_loss(batch)
loss = cast(torch.Tensor, loss)
except Exception as e:
logger.warning(f"Failed to compute loss for batch {step}: {e}")
failed_batches.append((step, 'forward', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many forward pass failures: {len(failed_batches)}")
continue
# Scale loss for gradient accumulation
loss = loss / self.config.gradient_accumulation_steps
# Backward pass with error handling
try:
if self.scaler:
self.scaler.scale(loss).backward() # type: ignore[attr-defined]
else:
loss.backward() # type: ignore[attr-defined]
except Exception as e:
logger.warning(f"Backward pass failed for batch {step}: {e}")
# Clear gradients and continue
self.optimizer.zero_grad()
failed_batches.append((step, 'backward', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many backward pass failures: {len(failed_batches)}")
continue
# Update weights with error handling
if (step + 1) % self.config.gradient_accumulation_steps == 0:
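                    # NOTE: if batches earlier in this accumulation window were
                    # skipped due to errors, this update aggregates fewer than
                    # gradient_accumulation_steps micro-batches.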
try:
# Gradient clipping
if self.scaler:
self.scaler.unscale_(self.optimizer)
clip_grad_norm_(
self.model.parameters(),
self.config.max_grad_norm
)
# Optimizer step
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
# Scheduler step
if self.scheduler:
self.scheduler.step()
# Update global step
self.global_step += 1
# Clear cache periodically for memory efficiency
if self.empty_cache_steps > 0 and self.global_step % self.empty_cache_steps == 0:
self._clear_device_cache()
# Zero gradients
self.optimizer.zero_grad()
except Exception as e:
logger.warning(f"Optimizer update failed at step {step}: {e}")
# Reset optimizer state
self.optimizer.zero_grad()
if self.scaler:
self.scaler.update()
failed_batches.append((step, 'optimizer', str(e)))
continue
                # Track the unscaled loss on every batch (not only on logging
                # steps) so the epoch average computed below is correct
                current_loss = loss.item() * self.config.gradient_accumulation_steps
                epoch_loss += current_loss
                # Logging with error handling
                if self.global_step % self.config.logging_steps == 0:
                    try:
# Update progress bar
progress_bar.set_postfix({
'loss': f'{current_loss:.4f}',
'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' if self.scheduler else 'N/A',
'failures': len(failed_batches)
})
# Log to TensorBoard
if self.tb_logger:
self.tb_logger.log_scalar('train/loss', current_loss, self.global_step)
self.tb_logger.log_scalar('train/learning_rate',
self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr'],
self.global_step)
except Exception as e:
logger.debug(f"Logging failed: {e}")
# Checkpoint saving with error handling
if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
try:
if self.checkpoint_manager:
state_dict = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'global_step': self.global_step,
'current_epoch': epoch,
                                'config': self.config.to_dict()
}
self.checkpoint_manager.save(state_dict, self.global_step)
except Exception as e:
logger.error(f"Failed to save checkpoint at step {self.global_step}: {e}")
# Continue training even if checkpoint fails
except Exception as e:
# Catch-all for any unexpected errors
logger.error(f"Unexpected error in training step {step}: {e}")
failed_batches.append((step, 'unknown', str(e)))
if len(failed_batches) > max_failures * 2:
raise RuntimeError(f"Training failed with too many errors: {len(failed_batches)}")
continue
# Log epoch summary
if failed_batches:
logger.warning(f"Epoch {epoch + 1} completed with {len(failed_batches)} failed batches")
for step, stage, error in failed_batches[:5]: # Show first 5 failures
logger.debug(f" Batch {step} failed at {stage}: {error}")
# Compute epoch metrics
epoch_metrics['loss'] = epoch_loss / max(num_batches - len(failed_batches), 1)
epoch_metrics['failed_batches'] = len(failed_batches)
epoch_metrics['success_rate'] = (num_batches - len(failed_batches)) / num_batches if num_batches > 0 else 0.0
return epoch_metrics
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
"""Evaluate the model"""
self.model.eval()
num_batches = 0
all_metrics = {}
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating", disable=not is_main_process()):
batch = self._prepare_batch(batch)
# Compute metrics
metrics = self.evaluate_step(batch)
# Accumulate metrics
for key, value in metrics.items():
if key not in all_metrics:
all_metrics[key] = 0
all_metrics[key] += value
num_batches += 1
# Average metrics
for key in all_metrics:
all_metrics[key] /= num_batches
# Gather metrics across devices
if self.is_distributed:
all_metrics = reduce_dict(all_metrics)
return all_metrics
def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
"""Move tensor values in batch dict to the target device, leave non-tensors untouched."""
moved = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
moved[k] = v.to(self.device, non_blocking=True)
else:
moved[k] = v # keep lists / strings / ints as-is
return moved
def _clear_device_cache(self):
"""Clear device memory cache with support for CUDA and Ascend NPU"""
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# Try Ascend support if available
try:
from utils.ascend_support import clear_ascend_cache
clear_ascend_cache()
except ImportError:
# Fallback to manual NPU check
try:
npu = getattr(torch, 'npu', None)
if npu is not None and hasattr(npu, 'is_available') and npu.is_available():
npu.empty_cache()
except Exception:
pass
except Exception as e:
logger.debug(f"Failed to clear device cache: {e}")
def _is_better_metric(self, current: float) -> bool:
"""Check if current metric is better than best"""
if self.config.greater_is_better:
return current > self.best_metric
else:
return current < self.best_metric
def _save_checkpoint(self, is_best: bool = False, metrics: Optional[Dict[str, float]] = None):
"""Save model checkpoint"""
if not is_main_process() or not self.checkpoint_manager:
return
checkpoint = {
'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(), # type: ignore[attr-defined]
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'config': self.config.to_dict(),
'global_step': self.global_step,
'current_epoch': self.current_epoch,
'best_metric': self.best_metric
}
if self.scaler:
checkpoint['scaler_state_dict'] = self.scaler.state_dict()
self.checkpoint_manager.save(
checkpoint,
step=self.global_step,
is_best=is_best,
metrics=metrics
)
def _save_training_summary(self):
"""Save training summary"""
if not is_main_process():
return
summary = {
'config': self.config.to_dict(),
'final_metrics': self.training_history[-1] if self.training_history else {},
'best_metric': self.best_metric,
'total_steps': self.global_step,
'total_epochs': self.current_epoch,
            'finished_at': datetime.now().isoformat()  # completion timestamp, not a duration
}
summary_path = os.path.join(self.config.output_dir, 'training_summary.json')
with open(summary_path, 'w') as f:
json.dump(summary, f, indent=2)
def load_checkpoint(self, checkpoint_path: str):
"""Load checkpoint"""
logger.info(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Load model state
if self.is_distributed:
self.model.module.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
else:
self.model.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
# Load optimizer state
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Load scheduler state
if self.scheduler and checkpoint.get('scheduler_state_dict'):
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# Load scaler state
if self.scaler and checkpoint.get('scaler_state_dict'):
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
# Load training state
self.global_step = checkpoint.get('global_step', 0)
self.current_epoch = checkpoint.get('current_epoch', 0)
        self.best_metric = checkpoint.get(
            'best_metric',
            float('-inf') if self.config.greater_is_better else float('inf')
        )
logger.info(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")