init commit
This commit is contained in:
40
training/__init__.py
Normal file
40
training/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Training modules for BGE models
|
||||
"""
|
||||
|
||||
def get_trainer(trainer_type: str):
    """Resolve a trainer class by name, importing it only on first use.

    Keeping the imports inside the function avoids pulling every trainer
    (and its heavy model dependencies) into memory at package import time.

    Args:
        trainer_type: One of 'base', 'm3', 'reranker' or 'joint'.

    Returns:
        The requested trainer class (not an instance).

    Raises:
        ValueError: If ``trainer_type`` is not a known trainer name.
    """
    if trainer_type == 'base':
        from .trainer import BaseTrainer as trainer_cls
    elif trainer_type == 'm3':
        from .m3_trainer import BGEM3Trainer as trainer_cls
    elif trainer_type == 'reranker':
        from .reranker_trainer import BGERerankerTrainer as trainer_cls
    elif trainer_type == 'joint':
        from .joint_trainer import JointTrainer as trainer_cls
    else:
        raise ValueError(f"Unknown trainer type: {trainer_type}")
    return trainer_cls
|
||||
|
||||
# For backward compatibility, import the commonly used ones
|
||||
from .trainer import BaseTrainer
|
||||
from .m3_trainer import BGEM3Trainer
|
||||
|
||||
# Only import reranker and joint trainers when explicitly requested
|
||||
def __getattr__(name):
    """Module-level lazy attribute hook (PEP 562).

    Defers importing the reranker and joint trainers until the attribute is
    actually accessed, so `import training` stays cheap.
    """
    # Guard-clause style: resolve each lazy name, fall through to the error.
    if name == 'JointTrainer':
        from .joint_trainer import JointTrainer
        return JointTrainer
    if name == 'BGERerankerTrainer':
        from .reranker_trainer import BGERerankerTrainer
        return BGERerankerTrainer
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
# Public API: only the eagerly imported trainers and the factory are listed.
# BGERerankerTrainer and JointTrainer are intentionally omitted — they are
# served lazily by the module-level __getattr__ hook on first access.
__all__ = [
    'BaseTrainer',
    'BGEM3Trainer',
    'get_trainer'
]
|
||||
419
training/joint_trainer.py
Normal file
419
training/joint_trainer.py
Normal file
@@ -0,0 +1,419 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from .trainer import BaseTrainer
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from models.losses import (
|
||||
InfoNCELoss, ListwiseCrossEntropyLoss,
|
||||
DynamicListwiseDistillationLoss, MultiTaskLoss
|
||||
)
|
||||
from data.dataset import JointDataset
|
||||
from config.base_config import BaseConfig
|
||||
from utils.distributed import is_main_process, get_world_size, reduce_dict
|
||||
from utils.logging import get_logger
|
||||
|
||||
# Import autocast and GradScaler with fallback
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
# Scheduler import
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class JointTrainer(BaseTrainer):
    """
    Joint trainer for BGE-M3 and BGE-Reranker using RocketQAv2 approach

    Trains the retriever (bi-encoder) and reranker (cross-encoder) side by
    side: each step computes a loss per model, optionally adds a mutual
    distillation term, then runs separate backward passes and separate
    optimizer steps.

    Notes:
    - All parameters (config, retriever, reranker, optimizers, device, etc.) should be passed from config-driven training scripts.
    - This class does not load config directly; pass config values from your main script for consistency.
    """

    def __init__(
        self,
        config: BaseConfig,
        retriever: BGEM3Model,
        reranker: BGERerankerModel,
        retriever_optimizer: Optional[torch.optim.Optimizer] = None,
        reranker_optimizer: Optional[torch.optim.Optimizer] = None,
        device: Optional[torch.device] = None
    ):
        # Create default optimizer if none provided
        if retriever_optimizer is None:
            retriever_optimizer = AdamW(
                retriever.parameters(),
                lr=config.learning_rate,
                weight_decay=config.weight_decay
            )

        # We'll use the retriever as the main model for the base class
        # (BaseTrainer owns self.model / self.optimizer / self.scaler etc.)
        super().__init__(config, retriever, retriever_optimizer, device=device)

        self.retriever = self.model  # Alias for clarity
        self.reranker = reranker
        self.reranker.to(self.device)

        # Setup reranker optimizer if not provided
        if reranker_optimizer is None:
            self.reranker_optimizer = AdamW(
                self.reranker.parameters(),
                lr=config.learning_rate * 2,  # Higher LR for reranker
                weight_decay=config.weight_decay
            )
        else:
            self.reranker_optimizer = reranker_optimizer

        # Initialize losses
        self.retriever_loss_fn = InfoNCELoss(
            temperature=config.temperature if hasattr(config, 'temperature') else 0.02,
            use_in_batch_negatives=True  # type: ignore[call-arg]
        )

        self.reranker_loss_fn = ListwiseCrossEntropyLoss(
            temperature=1.0,
            label_smoothing=0.0  # type: ignore[call-arg]
        )

        self.distillation_loss_fn = DynamicListwiseDistillationLoss(
            temperature=1.0,
            alpha=0.5  # type: ignore[call-arg]
        )

        # Joint training specific parameters
        # update_retriever_interval: steps between hard-negative refreshes
        self.update_retriever_interval = getattr(config, 'update_retriever_steps', 100)
        # joint_loss_weight: weight of the distillation term added to each loss
        self.joint_loss_weight = getattr(config, 'joint_loss_weight', 0.3)
        self.use_dynamic_distillation = getattr(config, 'use_dynamic_listwise_distillation', True)

        # Tracking metrics (currently unused accumulators kept for extensions)
        self.retriever_metrics = []
        self.reranker_metrics = []

    def compute_loss(self, batch: Tuple[Dict, Dict]) -> torch.Tensor:
        """
        Compute joint loss for both retriever and reranker

        Side effects: stores the per-model losses on
        ``self.current_retriever_loss`` / ``self.current_reranker_loss`` so
        that ``_train_epoch`` can run separate backward passes.

        Args:
            batch: Tuple of (retriever_batch, reranker_batch)

        Returns:
            Combined (retriever + reranker) loss, used for logging only.
        """
        retriever_batch, reranker_batch = batch

        # 1. Compute retriever loss
        retriever_loss = self._compute_retriever_loss(retriever_batch)

        # 2. Compute reranker loss
        reranker_loss = self._compute_reranker_loss(reranker_batch)

        # 3. Compute distillation losses if enabled
        if self.use_dynamic_distillation:
            # Get scores from both models
            # NOTE(review): both score tensors are produced under
            # torch.no_grad() (and _get_retriever_scores wraps no_grad
            # internally as well), so the distillation terms below are
            # detached constants — they shift the logged loss value but
            # contribute no gradient to either model. Confirm this is the
            # intended behavior; if gradients should flow, the student-side
            # scores must be recomputed with grad enabled.
            with torch.no_grad():
                retriever_scores = self._get_retriever_scores(retriever_batch)
                reranker_scores = self._get_reranker_scores(reranker_batch)

            # Compute mutual distillation
            retriever_distill_loss, reranker_distill_loss = self.distillation_loss_fn(
                retriever_scores,
                reranker_scores,
                labels=reranker_batch.get('labels')
            )

            # Add distillation to main losses
            retriever_loss = retriever_loss + self.joint_loss_weight * retriever_distill_loss
            reranker_loss = reranker_loss + self.joint_loss_weight * reranker_distill_loss

        # Combine losses (we'll handle separate backward passes)
        self.current_retriever_loss = retriever_loss
        self.current_reranker_loss = reranker_loss

        # Return combined loss for logging
        return retriever_loss + reranker_loss

    def _compute_retriever_loss(self, batch: Dict) -> torch.Tensor:
        """Compute retriever (bi-encoder) loss via InfoNCE on dense embeddings."""
        # Get embeddings
        query_outputs = self.retriever(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )

        passage_outputs = self.retriever(
            input_ids=batch['passage_input_ids'],
            attention_mask=batch['passage_attention_mask'],
            return_dense=True
        )

        # Compute contrastive loss (in-batch negatives; labels are optional)
        loss = self.retriever_loss_fn(
            query_outputs['dense'],
            passage_outputs['dense'],
            labels=batch.get('labels')
        )

        return loss

    def _compute_reranker_loss(self, batch: Dict) -> torch.Tensor:
        """Compute reranker (cross-encoder) loss.

        Chooses binary cross-entropy for 1-D logits, listwise CE otherwise.
        """
        outputs = self.reranker(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )

        logits = outputs['logits']
        labels = batch['labels']

        # Reshape for listwise loss if needed
        if logits.dim() == 1:
            # Binary classification
            loss = nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
        else:
            # Listwise ranking
            loss = self.reranker_loss_fn(logits, labels)

        return loss

    def _get_retriever_scores(self, batch: Dict) -> torch.Tensor:
        """Get retriever similarity scores (query-passage dot products), detached."""
        with torch.no_grad():
            query_outputs = self.retriever(
                input_ids=batch['query_input_ids'],
                attention_mask=batch['query_attention_mask'],
                return_dense=True
            )

            passage_outputs = self.retriever(
                input_ids=batch['passage_input_ids'],
                attention_mask=batch['passage_attention_mask'],
                return_dense=True
            )

            # Compute similarities — full query x passage score matrix
            scores = torch.matmul(
                query_outputs['dense'],
                passage_outputs['dense'].transpose(-1, -2)
            )

        return scores

    def _get_reranker_scores(self, batch: Dict) -> torch.Tensor:
        """Get reranker scores (raw logits; gradient depends on caller's context)."""
        outputs = self.reranker(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        return outputs['logits']

    def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train both models for one epoch.

        Runs one backward pass per model per batch, with shared AMP scaler
        and gradient accumulation; optimizer/scheduler stepping happens only
        on accumulation boundaries.

        Returns:
            Dict of epoch-averaged 'retriever_loss', 'reranker_loss',
            'total_loss' (reduced across ranks when distributed).
        """
        self.retriever.train()
        self.reranker.train()

        epoch_metrics = {
            'retriever_loss': 0.0,
            'reranker_loss': 0.0,
            'total_loss': 0.0
        }
        num_batches = 0

        # Set epoch for distributed sampler
        if hasattr(dataloader, 'sampler') and hasattr(dataloader.sampler, 'set_epoch'):
            dataloader.sampler.set_epoch(epoch)  # type: ignore[attr-defined]

        progress_bar = tqdm(
            dataloader,
            desc=f"Joint Training Epoch {epoch + 1}",
            disable=not is_main_process()
        )

        for step, batch in enumerate(progress_bar):
            # Prepare batch (move tensors to device via BaseTrainer helper)
            retriever_batch, reranker_batch = batch
            retriever_batch = self._prepare_batch(retriever_batch)
            reranker_batch = self._prepare_batch(reranker_batch)
            batch = (retriever_batch, reranker_batch)

            # Forward pass (also populates self.current_*_loss)
            total_loss = self.compute_loss(batch)

            # Scale losses for gradient accumulation
            retriever_loss = self.current_retriever_loss / self.config.gradient_accumulation_steps
            reranker_loss = self.current_reranker_loss / self.config.gradient_accumulation_steps

            # Backward passes (separate for each model)
            # NOTE(review): retain_graph=True suggests the two losses may
            # share part of a graph; with distillation detached they look
            # independent — confirm whether retain_graph is still required.
            if self.scaler:
                # Retriever backward
                self.scaler.scale(retriever_loss).backward(retain_graph=True)  # type: ignore[attr-defined]
                # Reranker backward
                self.scaler.scale(reranker_loss).backward()  # type: ignore[attr-defined]
            else:
                retriever_loss.backward(retain_graph=True)
                reranker_loss.backward()

            # Update weights only on accumulation boundaries
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradient clipping (unscale first so clipping sees true grads)
                if self.scaler:
                    self.scaler.unscale_(self.optimizer)
                    self.scaler.unscale_(self.reranker_optimizer)

                torch.nn.utils.clip_grad_norm_(  # type: ignore[attr-defined]
                    self.retriever.parameters(),
                    self.config.max_grad_norm
                )
                torch.nn.utils.clip_grad_norm_(  # type: ignore[attr-defined]
                    self.reranker.parameters(),
                    self.config.max_grad_norm
                )

                # Optimizer steps — one shared scaler steps both optimizers,
                # then update() is called exactly once per iteration.
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.step(self.reranker_optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                    self.reranker_optimizer.step()

                # Scheduler steps (retriever only)
                if self.scheduler:
                    self.scheduler.step()
                    # Note: Add reranker scheduler if needed

                # Zero gradients
                self.optimizer.zero_grad()
                self.reranker_optimizer.zero_grad()

                self.global_step += 1

                # Update hard negatives periodically
                if self.global_step % self.update_retriever_interval == 0:
                    self._update_hard_negatives()

                # Logging
                if self.global_step % self.config.logging_steps == 0:
                    # Multiply back by accumulation steps to report the
                    # un-scaled per-batch loss.
                    step_metrics = {
                        'retriever_loss': retriever_loss.item() * self.config.gradient_accumulation_steps,
                        'reranker_loss': reranker_loss.item() * self.config.gradient_accumulation_steps,
                        'total_loss': (
                            retriever_loss.item() + reranker_loss.item()) * self.config.gradient_accumulation_steps,
                        'retriever_lr': self.optimizer.param_groups[0]['lr'],
                        'reranker_lr': self.reranker_optimizer.param_groups[0]['lr']
                    }

                    if self.is_distributed:
                        step_metrics = reduce_dict(step_metrics)

                    progress_bar.set_postfix({
                        'ret_loss': f"{step_metrics['retriever_loss']:.4f}",
                        'rank_loss': f"{step_metrics['reranker_loss']:.4f}"
                    })

                    if is_main_process():
                        for key, value in step_metrics.items():
                            self.tb_logger.log_scalar(f"joint/{key}", value, self.global_step)
                        # Flush logs periodically
                        if self.global_step % (self.config.logging_steps * 10) == 0:
                            self.tb_logger.flush()

            # Accumulate metrics (still divided by accumulation steps here;
            # compensated in the averaging below)
            epoch_metrics['retriever_loss'] += retriever_loss.item()
            epoch_metrics['reranker_loss'] += reranker_loss.item()
            epoch_metrics['total_loss'] += (retriever_loss.item() + reranker_loss.item())
            num_batches += 1

        # Average metrics (undo the per-batch accumulation scaling)
        for key in epoch_metrics:
            epoch_metrics[key] = epoch_metrics[key] / num_batches * self.config.gradient_accumulation_steps

        if self.is_distributed:
            epoch_metrics = reduce_dict(epoch_metrics)

        return epoch_metrics

    def evaluate_step(self, batch: Tuple[Dict, Dict]) -> Dict[str, float]:
        """Evaluate both models on a batch.

        Returns per-model losses, top-1 accuracies and their summed loss.
        Assumes labels are one-hot-like so ``labels.argmax`` identifies the
        positive — TODO confirm against the dataset.
        """
        retriever_batch, reranker_batch = batch

        metrics = {}

        # Evaluate retriever
        with torch.no_grad():
            retriever_loss = self._compute_retriever_loss(retriever_batch)
            metrics['retriever_loss'] = retriever_loss.item()

            # Compute retrieval metrics
            retriever_scores = self._get_retriever_scores(retriever_batch)
            labels = retriever_batch['labels']

            # Simple accuracy: is the positive passage ranked first?
            top_indices = retriever_scores.argmax(dim=-1)
            correct = (top_indices == labels.argmax(dim=-1)).float().mean()
            metrics['retriever_accuracy'] = correct.item()

        # Evaluate reranker
        with torch.no_grad():
            reranker_loss = self._compute_reranker_loss(reranker_batch)
            metrics['reranker_loss'] = reranker_loss.item()

            # Compute reranking metrics
            reranker_scores = self._get_reranker_scores(reranker_batch)
            labels = reranker_batch['labels']

            if reranker_scores.dim() == 1:
                # Binary accuracy (threshold logits at 0)
                predictions = (reranker_scores > 0).float()
                correct = (predictions == labels).float().mean()
            else:
                # Ranking accuracy
                top_indices = reranker_scores.argmax(dim=-1)
                correct = (top_indices == labels.argmax(dim=-1)).float().mean()

            metrics['reranker_accuracy'] = correct.item()

        metrics['total_loss'] = metrics['retriever_loss'] + metrics['reranker_loss']

        return metrics

    def _update_hard_negatives(self):
        """Update hard negatives using current models (simplified version)"""
        # This would implement the dynamic hard negative mining
        # For now, just log that we would update — intentionally a stub.
        if is_main_process():
            logger.info(f"Would update hard negatives at step {self.global_step}")

    def save_models(self, output_dir: str):
        """Save both models to ``output_dir/retriever`` and ``output_dir/reranker``.

        No-op on non-main ranks in distributed training.
        """
        if not is_main_process():
            return

        # Save retriever (unwrap DDP's .module when present)
        retriever_path = os.path.join(output_dir, "retriever")
        os.makedirs(retriever_path, exist_ok=True)

        if hasattr(self.retriever, 'module'):
            self.retriever.module.save_pretrained(retriever_path)  # type: ignore[attr-defined]
        else:
            self.retriever.save_pretrained(retriever_path)  # type: ignore[attr-defined]

        # Save reranker
        reranker_path = os.path.join(output_dir, "reranker")
        os.makedirs(reranker_path, exist_ok=True)

        if hasattr(self.reranker, 'module'):
            self.reranker.module.save_pretrained(reranker_path)  # type: ignore[attr-defined]
        else:
            self.reranker.save_pretrained(reranker_path)  # type: ignore[attr-defined]

        logger.info(f"Saved models to {output_dir}")
|
||||
894
training/m3_trainer.py
Normal file
894
training/m3_trainer.py
Normal file
@@ -0,0 +1,894 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from transformers import AutoTokenizer
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
import numpy as np
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.distributed as dist
|
||||
|
||||
from .trainer import BaseTrainer
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.losses import InfoNCELoss, M3KnowledgeDistillationLoss, MultiTaskLoss
|
||||
from data.dataset import BGEM3Dataset
|
||||
|
||||
# Optional hard negative mining
|
||||
try:
|
||||
from data.hard_negative_mining import ANCEHardNegativeMiner
|
||||
ANCE_AVAILABLE = True
|
||||
except ImportError:
|
||||
ANCE_AVAILABLE = False
|
||||
ANCEHardNegativeMiner = None
|
||||
|
||||
from utils.checkpoint import CheckpointManager
|
||||
from utils.logging import get_logger
|
||||
from evaluation.metrics import compute_retrieval_metrics
|
||||
from config.m3_config import BGEM3Config
|
||||
|
||||
# Import scheduler with fallback
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BGEM3Trainer(BaseTrainer):
|
||||
"""
|
||||
Trainer for BGE-M3 embedding model with all SOTA techniques
|
||||
|
||||
Notes:
|
||||
- All parameters (config, model, tokenizer, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        config: BGEM3Config,
        model: Optional[BGEM3Model] = None,
        tokenizer: Optional[AutoTokenizer] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        """Build the BGE-M3 trainer.

        Args:
            config: BGE-M3 training configuration (hyperparameters, paths).
            model: Pre-built model; constructed from ``config`` when None.
            tokenizer: Pre-built tokenizer; loaded from ``config`` when None.
            optimizer: Optional optimizer; ``_create_optimizer`` is used when None.
            scheduler: Optional LR scheduler; otherwise created later via
                ``create_scheduler`` once the total step count is known.
            device: Target device; BaseTrainer decides when None.
        """
        # Store config first (needed by _create_optimizer before super().__init__)
        self.config = config

        # Initialize tokenizer first
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.model_name_or_path,
                cache_dir=config.cache_dir,
                trust_remote_code=True
            )
        else:
            self.tokenizer = tokenizer

        # Initialize model
        if model is None:
            self.model = BGEM3Model(
                model_name_or_path=config.model_name_or_path,
                use_dense=config.use_dense,
                use_sparse=config.use_sparse,
                use_colbert=config.use_colbert,
                pooling_method=config.sentence_pooling_method,
                normalize_embeddings=config.normalize_embeddings,
                colbert_dim=config.colbert_dim,
                sparse_top_k=config.sparse_top_k,
                cache_dir=config.cache_dir
            )
        else:
            self.model = model

        # Initialize optimizer if not provided
        if optimizer is None:
            optimizer = self._create_optimizer()

        # Initialize scheduler if not provided and get_linear_schedule_with_warmup is available
        if scheduler is None and get_linear_schedule_with_warmup is not None:
            # We'll create it later when we know the total training steps
            pass

        # Initialize parent class (owns device placement, scaler, DDP, etc.)
        super().__init__(config, self.model, optimizer, scheduler, device)

        # Initialize losses
        self.infonce_loss = InfoNCELoss(
            temperature=config.temperature,
            reduction='mean'
        )

        # Self-distillation across the three M3 heads
        self.kd_loss = M3KnowledgeDistillationLoss(
            temperature=config.self_distill_temperature,
            alpha=config.self_distill_alpha,
            distill_types=['dense', 'sparse', 'colbert']
        )

        # Weighted combination of per-head losses
        self.multi_task_loss = MultiTaskLoss(
            loss_weights={
                'dense': config.dense_weight,
                'sparse': config.sparse_weight,
                'colbert': config.colbert_weight,
                'distillation': config.distillation_weight
            }
        )

        # Initialize hard negative miner (optional dependency)
        if config.use_hard_negatives and ANCE_AVAILABLE:
            self.hard_negative_miner = ANCEHardNegativeMiner(
                embedding_dim=self.model.config.hidden_size,  # type: ignore[attr-defined]
                num_hard_negatives=config.num_hard_negatives,
                negative_range=config.ance_negative_range,
                margin=config.ance_margin,
                index_refresh_interval=config.hard_negative_mining_steps,
                # NOTE(review): str(device) == "cuda" is False for "cuda:0";
                # confirm whether per-index GPU devices should also enable this.
                use_gpu=str(self.device) == "cuda"
            )
            self.hard_negative_miner.start_async_updates()
        else:
            if config.use_hard_negatives and not ANCE_AVAILABLE:
                logger.warning("Hard negative mining requested but ANCEHardNegativeMiner not available")
            self.hard_negative_miner = None

        logger.info(f"Initialized BGE-M3 trainer with config: {config.to_dict()}")
|
||||
|
||||
def _create_optimizer(self) -> torch.optim.Optimizer:
|
||||
"""Create optimizer with parameter groups"""
|
||||
# Separate parameters for different learning rates if needed
|
||||
param_groups = [
|
||||
{
|
||||
'params': [p for n, p in self.model.named_parameters()
|
||||
if 'colbert_linear' not in n and 'sparse_linear' not in n],
|
||||
'lr': self.config.learning_rate
|
||||
}
|
||||
]
|
||||
|
||||
# Different LR for task-specific layers
|
||||
if self.config.use_colbert and hasattr(self.model, 'colbert_linear'):
|
||||
param_groups.append({
|
||||
'params': self.model.colbert_linear.parameters(),
|
||||
'lr': self.config.learning_rate * 2 # Higher LR for new layers
|
||||
})
|
||||
|
||||
if self.config.use_sparse and hasattr(self.model, 'sparse_linear'):
|
||||
param_groups.append({
|
||||
'params': self.model.sparse_linear.parameters(),
|
||||
'lr': self.config.learning_rate * 2
|
||||
})
|
||||
|
||||
optimizer = AdamW(
|
||||
param_groups,
|
||||
eps=self.config.adam_epsilon,
|
||||
weight_decay=self.config.weight_decay,
|
||||
betas=(self.config.adam_beta1, self.config.adam_beta2)
|
||||
)
|
||||
|
||||
return optimizer
|
||||
|
||||
def create_scheduler(self, num_training_steps: int):
|
||||
"""Create learning rate scheduler"""
|
||||
if get_linear_schedule_with_warmup is None:
|
||||
logger.warning("transformers not available, using constant learning rate")
|
||||
return None
|
||||
|
||||
num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
|
||||
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps
|
||||
)
|
||||
return self.scheduler
|
||||
|
||||
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Compute M3 loss with multiple objectives"""
|
||||
|
||||
# Update hard negative mining if hook is available
|
||||
if hasattr(self, '_mining_hook'):
|
||||
self._mining_hook()
|
||||
|
||||
# Detect batch format and handle accordingly
|
||||
if 'pos_input_ids' in batch and 'neg_input_ids' in batch:
|
||||
# This is triplet format from FastM3Dataset
|
||||
return self._compute_triplet_loss(batch)
|
||||
elif 'passage_input_ids' in batch:
|
||||
# This is standard embedding format
|
||||
return self._compute_embedding_loss(batch)
|
||||
else:
|
||||
raise ValueError(f"Unknown batch format. Keys: {list(batch.keys())}")
|
||||
|
||||
    def _compute_triplet_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for triplet format data (FastM3Dataset).

        Encodes query, positive and negative separately, then combines
        per-head losses (dense InfoNCE, optional sparse, optional ColBERT)
        through the multi-task weighting.

        Assumes one negative per query — TODO confirm against FastM3Dataset.
        """
        # Get query embeddings
        query_outputs = self.model.forward(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Get positive embeddings
        pos_outputs = self.model.forward(
            input_ids=batch['pos_input_ids'],
            attention_mask=batch['pos_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Get negative embeddings
        neg_outputs = self.model.forward(
            input_ids=batch['neg_input_ids'],
            attention_mask=batch['neg_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        losses = {}

        # Dense loss (InfoNCE with explicit pos/neg)
        if 'dense' in query_outputs:
            query_dense = query_outputs['dense'].float()  # type: ignore[index]
            pos_dense = pos_outputs['dense'].float()  # type: ignore[index]
            neg_dense = neg_outputs['dense'].float()  # type: ignore[index]

            # For triplet format, reshape negative embeddings to expected 3D format
            # InfoNCE expects negative_embeddings: [batch_size, num_negatives, embedding_dim]
            # But we have single negatives: [batch_size, embedding_dim]
            # Reshape to [batch_size, 1, embedding_dim]
            neg_dense_3d = neg_dense.unsqueeze(1)  # [batch_size, 1, embedding_dim]

            dense_loss = self.infonce_loss(
                query_dense,
                pos_dense,
                neg_dense_3d  # Now properly shaped 3D tensor
            )
            losses['dense'] = dense_loss

        # Sparse loss
        if self.config.use_sparse and 'sparse' in query_outputs:
            # Local import keeps the optional head's loss out of module load.
            from models.losses import SparseLoss
            sparse_loss_fn = SparseLoss(
                flops_loss_weight=1e-3,
                sparse_reg_weight=1e-4
            )

            # Combine pos and neg for sparse loss: queries are duplicated so
            # row i pairs with its positive and row i+batch with its negative.
            combined_sparse = torch.cat([pos_outputs['sparse'], neg_outputs['sparse']], dim=0)  # type: ignore[index]
            query_sparse_repeated = query_outputs['sparse'].repeat(2, 1)  # type: ignore[index]
            labels = torch.cat([torch.ones(len(pos_outputs['sparse'])),  # type: ignore[index]
                                torch.zeros(len(neg_outputs['sparse']))], dim=0).to(query_sparse_repeated.device)  # type: ignore[index]

            sparse_loss, _ = sparse_loss_fn(
                query_sparse_repeated,
                combined_sparse,
                labels
            )
            losses['sparse'] = sparse_loss

        # ColBERT loss
        if self.config.use_colbert and 'colbert' in query_outputs:
            from models.losses import ColBERTLoss
            colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)

            # Combine pos and neg for ColBERT loss (token-level embeddings,
            # hence the extra repeat dimension).
            combined_colbert = torch.cat([pos_outputs['colbert'], neg_outputs['colbert']], dim=0)  # type: ignore[index]
            query_colbert_repeated = query_outputs['colbert'].repeat(2, 1, 1)  # type: ignore[index]
            labels = torch.cat([torch.ones(len(pos_outputs['colbert'])),  # type: ignore[index]
                                torch.zeros(len(neg_outputs['colbert']))], dim=0).to(query_colbert_repeated.device)  # type: ignore[index]

            colbert_loss = colbert_loss_fn(
                query_colbert_repeated,
                combined_colbert,
                labels
            )
            losses['colbert'] = colbert_loss

        # Combine losses via configured multi-task weights
        total_loss, _ = self.multi_task_loss(losses)
        return total_loss
|
||||
|
||||
def _compute_embedding_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Compute loss for standard embedding format data"""
|
||||
# Get query embeddings
|
||||
query_outputs = self.model.forward(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=self.config.use_dense,
|
||||
return_sparse=self.config.use_sparse,
|
||||
return_colbert=self.config.use_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Get passage embeddings - handle both single and multiple passages
|
||||
if batch['passage_input_ids'].dim() == 3:
|
||||
# Multiple passages per query [batch, num_passages, seq_len]
|
||||
batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
|
||||
passage_input_ids = batch['passage_input_ids'].view(batch_size * num_passages, seq_len)
|
||||
passage_attention_mask = batch['passage_attention_mask'].view(batch_size * num_passages, seq_len)
|
||||
multiple_passages = True
|
||||
else:
|
||||
# Single passage per query [batch, seq_len]
|
||||
passage_input_ids = batch['passage_input_ids']
|
||||
passage_attention_mask = batch['passage_attention_mask']
|
||||
batch_size = passage_input_ids.shape[0]
|
||||
num_passages = 1
|
||||
multiple_passages = False
|
||||
|
||||
passage_outputs = self.model.forward(
|
||||
input_ids=passage_input_ids,
|
||||
attention_mask=passage_attention_mask,
|
||||
return_dense=True,
|
||||
return_sparse=self.config.use_sparse,
|
||||
return_colbert=self.config.use_colbert,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
# Compute losses for different components
|
||||
losses = {}
|
||||
|
||||
# Dense loss (InfoNCE)
|
||||
if 'dense' in query_outputs and 'dense' in passage_outputs:
|
||||
# Get embeddings and ensure they're on the same device
|
||||
query_dense = query_outputs['dense'].float() # type: ignore[index] # [batch_size, embedding_dim]
|
||||
passage_dense = passage_outputs['dense'].float() # type: ignore[index] # [batch_size * num_passages, embedding_dim]
|
||||
|
||||
# Handle multiple passages per query differently for contrastive learning
|
||||
if multiple_passages:
|
||||
# Reshape passage embeddings back to [batch_size, num_passages, embedding_dim]
|
||||
passage_dense = passage_dense.view(batch_size, num_passages, -1)
|
||||
|
||||
# For contrastive learning, we need to identify positive passages
|
||||
# Use the labels to find the first positive passage for each query
|
||||
labels = batch.get('labels') # [batch_size, num_passages]
|
||||
|
||||
if labels is not None:
|
||||
# Find the first positive passage for each query
|
||||
positive_indices = []
|
||||
negative_passages = []
|
||||
|
||||
for i in range(batch_size):
|
||||
query_labels = labels[i] # [num_passages]
|
||||
positive_mask = query_labels > 0
|
||||
|
||||
if positive_mask.any():
|
||||
# Get first positive passage
|
||||
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
|
||||
positive_indices.append(pos_idx)
|
||||
|
||||
# Collect negative passages
|
||||
neg_mask = ~positive_mask
|
||||
if neg_mask.any():
|
||||
neg_embeddings = passage_dense[i][neg_mask] # [num_negatives, embedding_dim]
|
||||
negative_passages.append(neg_embeddings)
|
||||
else:
|
||||
negative_passages.append(None)
|
||||
else:
|
||||
# No positive passage, use first passage as positive (fallback)
|
||||
positive_indices.append(0)
|
||||
negative_passages.append(passage_dense[i][1:])
|
||||
|
||||
# Extract positive embeddings
|
||||
positive_embeddings = torch.stack([
|
||||
passage_dense[i, positive_indices[i]] for i in range(batch_size)
|
||||
]) # [batch_size, embedding_dim]
|
||||
|
||||
# Stack negative embeddings (if available)
|
||||
max_neg = max(neg.shape[0] if neg is not None else 0 for neg in negative_passages)
|
||||
if max_neg > 0:
|
||||
# Pad negative embeddings to same size
|
||||
padded_negatives = []
|
||||
for neg in negative_passages:
|
||||
if neg is not None and neg.shape[0] > 0:
|
||||
if neg.shape[0] < max_neg:
|
||||
# Pad with zeros
|
||||
padding = torch.zeros(max_neg - neg.shape[0], neg.shape[1], device=neg.device)
|
||||
neg = torch.cat([neg, padding], dim=0)
|
||||
padded_negatives.append(neg)
|
||||
else:
|
||||
# No negatives, create zero tensor
|
||||
padded_negatives.append(torch.zeros(max_neg, passage_dense.shape[-1], device=passage_dense.device))
|
||||
|
||||
negative_embeddings = torch.stack(padded_negatives) # [batch_size, max_neg, embedding_dim]
|
||||
else:
|
||||
negative_embeddings = None
|
||||
|
||||
# Compute InfoNCE loss with explicit positives and negatives
|
||||
dense_loss = self.infonce_loss(
|
||||
query_dense,
|
||||
positive_embeddings,
|
||||
negative_embeddings
|
||||
)
|
||||
else:
|
||||
# No labels provided, use first passage as positive, rest as negatives
|
||||
positive_embeddings = passage_dense[:, 0, :] # [batch_size, embedding_dim]
|
||||
if num_passages > 1:
|
||||
negative_embeddings = passage_dense[:, 1:, :] # [batch_size, num_passages-1, embedding_dim]
|
||||
else:
|
||||
negative_embeddings = None
|
||||
|
||||
dense_loss = self.infonce_loss(
|
||||
query_dense,
|
||||
positive_embeddings,
|
||||
negative_embeddings
|
||||
)
|
||||
else:
|
||||
# Single passage per query - use standard in-batch negatives
|
||||
# InfoNCE will create contrastive pairs using in-batch negatives
|
||||
dense_loss = self.infonce_loss(
|
||||
query_dense,
|
||||
passage_dense,
|
||||
None # No explicit negatives, will use in-batch
|
||||
)
|
||||
|
||||
losses['dense'] = dense_loss
|
||||
|
||||
# Sparse loss (if using sparse)
|
||||
if self.config.use_sparse and 'sparse' in query_outputs and 'sparse' in passage_outputs:
|
||||
from models.losses import SparseLoss
|
||||
sparse_loss_fn = SparseLoss(
|
||||
flops_loss_weight=1e-3,
|
||||
sparse_reg_weight=1e-4
|
||||
)
|
||||
|
||||
query_sparse = query_outputs['sparse'] # type: ignore[index] # [batch_size, query_seq_len]
|
||||
passage_sparse = passage_outputs['sparse'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len]
|
||||
|
||||
# For sparse loss, we need to handle the contrastive learning differently
|
||||
# Since sparse embeddings have different sequence lengths, we'll use the same approach as dense
|
||||
if multiple_passages:
|
||||
# Use only the first positive passage for each query
|
||||
labels = batch.get('labels')
|
||||
if labels is not None:
|
||||
positive_indices = []
|
||||
for i in range(batch_size):
|
||||
query_labels = labels[i]
|
||||
positive_mask = query_labels > 0
|
||||
if positive_mask.any():
|
||||
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
|
||||
else:
|
||||
pos_idx = 0
|
||||
positive_indices.append(pos_idx + i * num_passages)
|
||||
|
||||
# Extract positive sparse embeddings
|
||||
positive_sparse = passage_sparse[positive_indices] # [batch_size, doc_seq_len]
|
||||
else:
|
||||
# Use first passage as positive
|
||||
positive_sparse = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
|
||||
|
||||
# Create contrastive labels for sparse loss
|
||||
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
|
||||
|
||||
sparse_loss, _ = sparse_loss_fn(
|
||||
query_sparse,
|
||||
positive_sparse,
|
||||
contrastive_labels
|
||||
)
|
||||
else:
|
||||
# Single passage per query
|
||||
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
|
||||
sparse_loss, _ = sparse_loss_fn(
|
||||
query_sparse,
|
||||
passage_sparse,
|
||||
contrastive_labels
|
||||
)
|
||||
|
||||
losses['sparse'] = sparse_loss
|
||||
|
||||
# ColBERT loss (if using ColBERT)
|
||||
if self.config.use_colbert and 'colbert' in query_outputs and 'colbert' in passage_outputs:
|
||||
from models.losses import ColBERTLoss
|
||||
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
|
||||
|
||||
query_colbert = query_outputs['colbert'] # type: ignore[index] # [batch_size, query_seq_len, dim]
|
||||
passage_colbert = passage_outputs['colbert'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len, dim]
|
||||
|
||||
# For ColBERT, we need pairwise interactions
|
||||
if multiple_passages:
|
||||
# Use only positive passages
|
||||
labels = batch.get('labels')
|
||||
if labels is not None:
|
||||
positive_indices = []
|
||||
for i in range(batch_size):
|
||||
query_labels = labels[i]
|
||||
positive_mask = query_labels > 0
|
||||
if positive_mask.any():
|
||||
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
|
||||
else:
|
||||
pos_idx = 0
|
||||
positive_indices.append(pos_idx + i * num_passages)
|
||||
|
||||
positive_colbert = passage_colbert[positive_indices]
|
||||
else:
|
||||
positive_colbert = passage_colbert[::num_passages]
|
||||
|
||||
# Binary labels for ColBERT (1 for positive pairs)
|
||||
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
|
||||
else:
|
||||
positive_colbert = passage_colbert
|
||||
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
|
||||
|
||||
colbert_loss = colbert_loss_fn(
|
||||
query_colbert,
|
||||
positive_colbert,
|
||||
colbert_labels
|
||||
)
|
||||
losses['colbert'] = colbert_loss
|
||||
|
||||
# Knowledge distillation loss (only if explicitly enabled)
|
||||
if (getattr(self.config, 'use_self_distill', False) and
|
||||
self.global_step >= getattr(self.config, 'self_distill_start_step', 0)):
|
||||
try:
|
||||
# Compute teacher scores (weighted combination)
|
||||
teacher_scores = self._compute_teacher_scores(query_outputs, passage_outputs) # type: ignore[arg-type]
|
||||
student_scores = self.model.compute_similarity(
|
||||
query_outputs, passage_outputs, mode='dense'
|
||||
)
|
||||
|
||||
kd_loss, _ = self.kd_loss(
|
||||
{'dense': student_scores},
|
||||
{'dense': teacher_scores}
|
||||
)
|
||||
losses['kd'] = kd_loss
|
||||
except Exception as e:
|
||||
logger.warning(f"Knowledge distillation failed: {e}")
|
||||
|
||||
# Combine losses
|
||||
total_loss, _ = self.multi_task_loss(losses)
|
||||
|
||||
return total_loss
|
||||
|
||||
def _compute_teacher_scores(
    self,
    query_outputs: Dict[str, torch.Tensor],
    passage_outputs: Dict[str, torch.Tensor]
) -> torch.Tensor:
    """Compute teacher scores as a weighted combination of per-mode similarities.

    For each representation mode present on both sides ('dense', 'sparse',
    'colbert'), query-passage similarity is computed and combined using the
    configured mode weights. When each query has several passages (passage
    batch larger than query batch), only the first passage of every query is
    scored, as a simplified teacher signal.

    Args:
        query_outputs: Per-mode query representations.
        passage_outputs: Per-mode passage representations, possibly with
            num_passages entries per query stacked along the batch dimension.

    Returns:
        Tensor with one teacher score per query; zeros if no mode could be scored.
    """
    scores: List[torch.Tensor] = []
    weights: List[float] = []

    # Get original batch sizes for detecting multiple passages per query.
    query_batch_size = query_outputs['dense'].shape[0] if 'dense' in query_outputs else 1
    passage_batch_size = passage_outputs['dense'].shape[0] if 'dense' in passage_outputs else 1

    multiple_passages = passage_batch_size > query_batch_size
    num_passages = passage_batch_size // query_batch_size if multiple_passages else 1

    # One scoring pass per mode replaces the previous threefold copy-paste.
    mode_specs = [
        ('dense', 'Dense', self.config.dense_weight),
        ('sparse', 'Sparse', self.config.sparse_weight),
        ('colbert', 'ColBERT', self.config.colbert_weight),
    ]
    for mode, label, weight in mode_specs:
        if mode not in query_outputs or mode not in passage_outputs:
            continue

        passage_repr = passage_outputs[mode]
        if multiple_passages:
            # Keep only the first (assumed positive) passage for each query.
            passage_repr = passage_repr[::num_passages]

        try:
            mode_scores = self.model.compute_similarity(
                {mode: query_outputs[mode]},
                {mode: passage_repr},
                mode=mode
            )
            scores.append(mode_scores)
            weights.append(weight)
        except Exception as e:
            logger.warning(f"{label} similarity computation failed: {e}")

    # Weighted combination of whatever modes succeeded.
    if scores:
        weight_tensor = torch.tensor(weights, device=scores[0].device)
        weight_tensor = weight_tensor / weight_tensor.sum()
        return torch.stack([w * s for w, s in zip(weight_tensor, scores)]).sum(dim=0)

    # Fallback to zero scores if no valid similarity computation.
    # (The original ternary here returned query_batch_size in both branches.)
    first_output = next(iter(query_outputs.values()))
    return torch.zeros(query_batch_size, device=first_output.device)
|
||||
|
||||
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Evaluate a single batch.

    Runs dense-only forward passes for queries and passages, computes the
    training loss and dense similarity scores, and derives simple metrics.

    Args:
        batch: Tokenized batch with query/passage input ids and attention
            masks; may optionally contain 'labels'.

    Returns:
        Dict with 'loss', 'accuracy' (scores thresholded at 0.5 against the
        labels) and 'mean_score'.
    """
    # Get query embeddings (dense only)
    query_outputs = self.model.forward(
        input_ids=batch['query_input_ids'],
        attention_mask=batch['query_attention_mask'],
        return_dense=True,
        return_dict=True
    )

    # Get passage embeddings (dense only)
    passage_outputs = self.model.forward(
        input_ids=batch['passage_input_ids'],
        attention_mask=batch['passage_attention_mask'],
        return_dense=True,
        return_dict=True
    )

    # Full training loss for this batch
    loss = self.compute_loss(batch)

    # Dense similarity between queries and passages
    scores = self.model.compute_similarity(
        query_outputs, passage_outputs, mode='dense'
    )

    # Basic metrics. The default labels tensor must live on the same device
    # as the scores; a bare torch.zeros(...) default is CPU-only and would
    # crash the comparison below when the model runs on GPU.
    labels = batch.get('labels', torch.zeros(scores.shape[0], device=scores.device))
    if labels.dim() > 1:
        labels = labels.view(-1)

    predictions = (scores > 0.5).float()
    accuracy = (predictions == labels.float()).float().mean()

    return {
        'loss': loss.item(),
        'accuracy': accuracy.item(),
        'mean_score': scores.mean().item()
    }
|
||||
|
||||
def train(
    self,
    train_dataset: BGEM3Dataset,
    eval_dataset: Optional[BGEM3Dataset] = None,
    resume_from_checkpoint: Optional[str] = None
):
    """
    Main training entry point: builds the dataloaders, sizes the LR schedule
    from the epoch/accumulation settings, optionally resumes from a
    checkpoint, and delegates the actual loop to the parent trainer.
    """
    try:
        train_loader = self._create_dataloader(train_dataset, is_train=True)
        eval_loader = None
        if eval_dataset:
            eval_loader = self._create_dataloader(eval_dataset, is_train=False)

        # Derive the total number of optimizer steps for the scheduler.
        steps_per_epoch = len(train_loader) // self.config.gradient_accumulation_steps  # type: ignore[arg-type]
        steps_per_epoch = max(1, steps_per_epoch)
        total_steps = steps_per_epoch * self.config.num_train_epochs

        self.create_scheduler(total_steps)

        if resume_from_checkpoint:
            self.load_checkpoint(resume_from_checkpoint)

        logger.info(f"Starting training for {self.config.num_train_epochs} epochs")
        logger.info(f"Total training steps: {total_steps}")

        # Parent class owns the robust epoch/step loop.
        super().train(train_loader, eval_loader, self.config.num_train_epochs)  # type: ignore[arg-type]

    except Exception as e:
        logger.error(f"M3 training failed: {e}")
        raise
    finally:
        # Always stop background hard-negative index refreshes.
        miner = getattr(self, 'hard_negative_miner', None)
        if miner:
            miner.stop_async_updates()
|
||||
|
||||
def _create_dataloader(
    self,
    dataset: Optional[BGEM3Dataset],
    is_train: bool = True
) -> Optional[DataLoader]:
    """Build a DataLoader for *dataset*; return None when no dataset is given."""
    if dataset is None:
        return None

    # Embedding-style batches all share one collate function.
    from data.dataset import collate_embedding_batch

    if is_train:
        batch_size = self.config.per_device_train_batch_size
    else:
        batch_size = self.config.per_device_eval_batch_size

    # Under DDP every rank gets its own shard via a DistributedSampler;
    # the sampler then owns shuffling, so DataLoader shuffle is disabled.
    sampler = None
    if self.is_distributed:
        from torch.utils.data.distributed import DistributedSampler
        sampler = DistributedSampler(dataset, shuffle=is_train)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train and sampler is None,
        sampler=sampler,
        num_workers=self.config.dataloader_num_workers,
        pin_memory=self.config.dataloader_pin_memory,
        # Drop the tail batch only when training AND more than one batch exists.
        drop_last=is_train and len(dataset) > batch_size,
        collate_fn=collate_embedding_batch
    )
|
||||
|
||||
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
    """Train one epoch via the parent loop, wiring in hard-negative refreshes."""
    if self.hard_negative_miner:
        # Hook consumed by compute_loss: refreshes the miner's index when due.
        def mining_hook():
            if self.hard_negative_miner:
                self.hard_negative_miner.update_index_if_needed(self.global_step, self.model)

        self._mining_hook = mining_hook

    # Delegate the robust per-batch loop to the base trainer.
    return super().train_epoch(dataloader, epoch)
|
||||
|
||||
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
    """Encode all queries/passages in *dataloader* and compute retrieval metrics."""
    self.model.eval()

    query_chunks = []
    passage_chunks = []
    label_chunks = []

    with torch.no_grad():
        with tqdm(dataloader, desc="Evaluating") as tbar:
            for batch in tbar:
                try:
                    batch = self._prepare_batch(batch)

                    q_out = self.model.forward(
                        input_ids=batch['query_input_ids'],
                        attention_mask=batch['query_attention_mask'],
                        return_dense=True,
                        return_dict=True
                    )
                    p_out = self.model.forward(
                        input_ids=batch['passage_input_ids'],
                        attention_mask=batch['passage_attention_mask'],
                        return_dense=True,
                        return_dict=True
                    )

                    # Accumulate on CPU to keep GPU memory flat during eval.
                    query_chunks.append(q_out['dense'].cpu())  # type: ignore[index]
                    passage_chunks.append(p_out['dense'].cpu())  # type: ignore[index]
                    label_chunks.append(batch['labels'].cpu())

                except Exception as e:
                    logger.error(f"Error in M3 evaluation step: {e}")
                    continue

    if not query_chunks:
        logger.warning("No valid evaluation batches processed")
        return {'loss': float('inf'), 'accuracy': 0.0}

    queries = torch.cat(query_chunks, dim=0)
    passages = torch.cat(passage_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0)

    try:
        return compute_retrieval_metrics(
            queries.numpy(),
            passages.numpy(),
            labels.numpy(),
            k_values=[1, 3, 5, 10]
        )
    except Exception as e:
        logger.error(f"Error computing retrieval metrics: {e}")
        return {'loss': float('inf'), 'accuracy': 0.0}
|
||||
|
||||
def save_model(self, save_path: str):
    """Persist model weights, tokenizer files and the training config to *save_path*."""
    try:
        os.makedirs(save_path, exist_ok=True)

        # Unwrap DDP/DataParallel (which expose the real model as .module).
        unwrapped = self.model.module if hasattr(self.model, 'module') else self.model
        unwrapped.save_pretrained(save_path)  # type: ignore[attr-defined]

        self.tokenizer.save_pretrained(save_path)  # type: ignore[attr-defined]

        self.config.save(os.path.join(save_path, 'training_config.json'))

        logger.info(f"Model saved to {save_path}")

    except Exception as e:
        logger.error(f"Error saving model: {e}")
        raise
|
||||
|
||||
def __del__(self):
    """Best-effort cleanup: stop async hard-negative mining on destruction."""
    miner = getattr(self, 'hard_negative_miner', None)
    if miner:
        miner.stop_async_updates()
|
||||
489
training/reranker_trainer.py
Normal file
489
training/reranker_trainer.py
Normal file
@@ -0,0 +1,489 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
from typing import Dict, Optional, List, Tuple, Any
|
||||
import logging
|
||||
import numpy as np
|
||||
from torch.nn import functional as F
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
from .trainer import BaseTrainer
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from models.losses import ListwiseCrossEntropyLoss, PairwiseMarginLoss
|
||||
from data.dataset import BGERerankerDataset
|
||||
from config.reranker_config import BGERerankerConfig
|
||||
from evaluation.metrics import compute_reranker_metrics
|
||||
from utils.distributed import is_main_process
|
||||
|
||||
# AMP imports
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
|
||||
# Scheduler import
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BGERerankerTrainer(BaseTrainer):
|
||||
"""
|
||||
Trainer for BGE-Reranker cross-encoder model
|
||||
|
||||
Notes:
|
||||
- All parameters (config, model, tokenizer, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    config: BGERerankerConfig,
    model: Optional[BGERerankerModel] = None,
    tokenizer=None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    device: Optional[torch.device] = None
):
    """Set up the reranker trainer: model, optimizer, ranking loss, and
    optional knowledge-distillation / denoising configuration."""
    # Build the cross-encoder if the caller did not supply one.
    if model is None:
        model = BGERerankerModel(
            model_name_or_path=config.model_name_or_path,
            num_labels=config.num_labels,
            cache_dir=config.cache_dir,
            use_fp16=config.fp16,
            gradient_checkpointing=config.gradient_checkpointing
        )

    # Default optimizer: AdamW with decay/no-decay parameter groups.
    if optimizer is None:
        optimizer = self._create_optimizer(model, config)

    super().__init__(config, model, optimizer, scheduler, device)

    self.tokenizer = tokenizer
    self.config = config

    # Select the ranking loss from config.
    loss_type = config.loss_type
    if loss_type == "listwise_ce":
        self.loss_fn = ListwiseCrossEntropyLoss(
            temperature=getattr(config, 'temperature', 1.0)
        )
    elif loss_type == "pairwise_margin":
        self.loss_fn = PairwiseMarginLoss(margin=config.margin)
    elif loss_type == "combined":
        # Combined mode keeps both losses; compute_loss mixes them.
        self.listwise_loss = ListwiseCrossEntropyLoss(
            temperature=getattr(config, 'temperature', 1.0)
        )
        self.pairwise_loss = PairwiseMarginLoss(margin=config.margin)
    else:
        self.loss_fn = nn.BCEWithLogitsLoss()

    # Optional knowledge distillation from teacher scores.
    self.use_kd = config.use_knowledge_distillation
    if self.use_kd:
        self.kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
        self.kd_temperature = config.distillation_temperature
        self.kd_loss_weight = config.distillation_alpha

    # Hard-negative sampling and denoising knobs.
    self.hard_negative_ratio = config.hard_negative_ratio
    self.use_denoising = config.use_denoising
    self.denoise_threshold = getattr(config, 'denoise_confidence_threshold', 0.5)
|
||||
|
||||
def _create_optimizer(self, model: nn.Module, config: BGERerankerConfig) -> torch.optim.Optimizer:
    """Build AdamW with weight decay disabled for biases and LayerNorm weights."""
    no_decay = ('bias', 'LayerNorm.weight')

    def wants_decay(param_name: str) -> bool:
        # Decay applies to everything except bias/LayerNorm parameters.
        return not any(marker in param_name for marker in no_decay)

    decay_params = [p for n, p in model.named_parameters() if wants_decay(n)]
    plain_params = [p for n, p in model.named_parameters() if not wants_decay(n)]

    param_groups = [
        {'params': decay_params, 'weight_decay': config.weight_decay},
        {'params': plain_params, 'weight_decay': 0.0},
    ]

    return AdamW(
        param_groups,
        lr=config.learning_rate,
        eps=config.adam_epsilon,
        betas=(config.adam_beta1, config.adam_beta2)
    )
|
||||
|
||||
def create_scheduler(self, num_training_steps: int) -> Optional[object]:
    """Create a linear warmup/decay LR scheduler; constant LR if transformers is absent."""
    if get_linear_schedule_with_warmup is None:
        logger.warning("transformers not available, using constant learning rate")
        return None

    # An explicit warmup_steps setting wins; otherwise derive from warmup_ratio.
    warmup = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)

    self.scheduler = get_linear_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=warmup,
        num_training_steps=num_training_steps
    )
    return self.scheduler
|
||||
|
||||
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Compute the reranker training loss for one batch.

    Supports listwise cross-entropy, pairwise margin, a weighted combination
    of both, and a BCE default; optionally mixes in a KL knowledge-distillation
    term and scales the loss by teacher confidence when denoising is enabled.

    Args:
        batch: Must contain 'input_ids', 'attention_mask' and 'labels'
            (label value -1 marks padding); may contain 'teacher_scores'.

    Returns:
        Scalar loss tensor.
    """
    # Get model outputs
    outputs = self.model.forward(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        return_dict=True
    )
    logits = outputs['logits'] if isinstance(outputs, dict) else outputs.scores

    labels = batch['labels']

    # Drop padding entries (label == -1) before computing the loss.
    if labels.dim() > 1:
        valid_mask = labels != -1
        if not valid_mask.any():
            # Entire batch is padding: zero (but differentiable) loss.
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
        logits = logits[valid_mask]
        labels = labels[valid_mask]

    def bce() -> torch.Tensor:
        # Shared binary cross-entropy fallback (was copy-pasted four times).
        return F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())

    def make_groups() -> list:
        # Split the flat batch into fixed-size query groups.
        return [self.config.train_group_size] * (len(labels) // self.config.train_group_size)

    loss_type = self.config.loss_type
    if loss_type == "listwise_ce":
        try:
            groups = make_groups()
            loss = self.loss_fn(logits.squeeze(-1), labels.float(), groups) if groups else bce()
        except Exception as e:
            logger.warning(f"Listwise CE loss failed: {e}. Falling back to BCE.")
            loss = bce()

    elif loss_type == "pairwise_margin":
        # Pairwise loss needs at least one positive and one negative example.
        if (labels == 1).any() and (labels == 0).any():
            try:
                loss = self.loss_fn(logits.squeeze(-1), labels.float(), make_groups())
            except Exception as e:
                logger.warning(f"Pairwise margin loss failed: {e}. Falling back to BCE.")
                loss = bce()
        else:
            logger.warning("No positive-negative pairs found for pairwise loss. Using BCE.")
            loss = bce()

    elif loss_type == "combined":
        try:
            groups = make_groups()
            if groups:
                listwise_loss = self.listwise_loss(logits.squeeze(-1), labels.float(), groups)
                pairwise_loss = self.pairwise_loss(logits.squeeze(-1), labels.float(), groups)
                loss = self.config.listwise_weight * listwise_loss + self.config.pairwise_weight * pairwise_loss
            else:
                loss = bce()
        except Exception as e:
            logger.warning(f"Combined loss computation failed: {e}. Falling back to BCE.")
            loss = bce()

    else:
        # Default: pointwise binary cross-entropy.
        loss = bce()

    # Knowledge distillation: KL between temperature-softened distributions.
    if self.use_kd and 'teacher_scores' in batch:
        teacher_scores = batch['teacher_scores']
        student_probs = F.log_softmax(logits / self.kd_temperature, dim=-1)
        teacher_probs = F.softmax(teacher_scores / self.kd_temperature, dim=-1)

        kd_loss = self.kd_loss_fn(student_probs, teacher_probs) * (self.kd_temperature ** 2)
        loss = (1 - self.kd_loss_weight) * loss + self.kd_loss_weight * kd_loss

    # Denoising: down-weight the loss by the fraction of confident teacher labels.
    if self.use_denoising and 'teacher_scores' in batch:
        confidence_mask = torch.sigmoid(batch['teacher_scores']) > self.denoise_threshold
        if confidence_mask.any():
            loss = loss * confidence_mask.float().mean()

    return loss
|
||||
|
||||
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Evaluate one batch; return loss plus ranking or classification metrics."""
    with torch.no_grad():
        outputs = self.model.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            return_dict=True
        )
        logits = outputs['logits'] if isinstance(outputs, dict) else outputs.scores

        labels = batch['labels']

        # Strip padded entries (label == -1).
        if labels.dim() > 1:
            valid_mask = labels != -1
            if not valid_mask.any():
                return {'loss': 0.0, 'accuracy': 0.0}
            logits = logits[valid_mask]
            labels = labels[valid_mask]

        loss = self.compute_loss(batch)

        grouped = hasattr(self.config, 'train_group_size') and self.config.train_group_size > 1
        if grouped:
            # Listwise evaluation over fixed-size query groups.
            groups = [self.config.train_group_size] * (len(labels) // self.config.train_group_size)
            if groups:
                try:
                    metrics = compute_reranker_metrics(
                        logits.squeeze(-1).cpu().numpy(),
                        labels.cpu().numpy(),
                        groups=groups,
                        k_values=[1, 3, 5, 10]
                    )
                    metrics['loss'] = loss.item()
                except Exception as e:
                    logger.warning(f"Error computing reranker metrics: {e}")
                    metrics = {'loss': loss.item(), 'accuracy': 0.0}
            else:
                metrics = {'loss': loss.item(), 'accuracy': 0.0}
        else:
            # Pointwise binary evaluation.
            preds = (torch.sigmoid(logits.squeeze(-1)) > 0.5).float()
            metrics = {
                'accuracy': (preds == labels.float()).float().mean().item(),
                'loss': loss.item()
            }

            # AUC is only meaningful when both classes are present.
            if len(torch.unique(labels)) > 1:
                try:
                    from sklearn.metrics import roc_auc_score
                    auc = roc_auc_score(labels.cpu().numpy(), torch.sigmoid(logits.squeeze(-1)).cpu().numpy())
                    metrics['auc'] = float(auc)
                except Exception as e:
                    logger.warning(f"Error computing AUC: {e}")

        return metrics
|
||||
|
||||
def train_with_hard_negative_mining(
    self,
    train_dataset: BGERerankerDataset,
    eval_dataset: Optional[BGERerankerDataset] = None,
    retriever_model: Optional[nn.Module] = None
):
    """
    Train the reranker, re-mining hard negatives from *retriever_model* at
    the start of every epoch after the first. Without a retriever this
    degenerates to a single standard training run.
    """
    if retriever_model is None:
        logger.warning("No retriever model provided for hard negative mining. Using standard training.")
        train_dataloader = self._create_dataloader(train_dataset, is_train=True)
        eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None

        if train_dataloader is None:
            logger.error("No valid training dataloader created")
            return
        return super().train(train_dataloader, eval_dataloader)

    # With a retriever: refresh negatives, then train one epoch at a time.
    for epoch in range(self.config.num_train_epochs):
        if epoch > 0:
            logger.info(f"Mining hard negatives for epoch {epoch + 1}")
            self._mine_hard_negatives(train_dataset, retriever_model)

        # Rebuild loaders so the freshly mined negatives are picked up.
        train_dataloader = self._create_dataloader(train_dataset, is_train=True)
        eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None

        if train_dataloader is None:
            logger.error(f"No valid training dataloader for epoch {epoch + 1}")
            continue

        super().train(train_dataloader, eval_dataloader, num_epochs=1)

        self.current_epoch = epoch
|
||||
|
||||
def _create_dataloader(self, dataset, is_train: bool = True) -> Optional[DataLoader]:
    """Build a DataLoader with the reranker collate function; None without a dataset."""
    if dataset is None:
        return None

    from data.dataset import collate_reranker_batch

    batch_size = (
        self.config.per_device_train_batch_size
        if is_train
        else self.config.per_device_eval_batch_size
    )

    # Shard across ranks when running distributed; the sampler owns shuffling.
    sampler = None
    if self.is_distributed:
        from torch.utils.data.distributed import DistributedSampler
        sampler = DistributedSampler(dataset, shuffle=is_train)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train and sampler is None,
        sampler=sampler,
        num_workers=self.config.dataloader_num_workers,
        pin_memory=self.config.dataloader_pin_memory,
        drop_last=is_train,
        collate_fn=collate_reranker_batch
    )
|
||||
|
||||
def _mine_hard_negatives(
|
||||
self,
|
||||
dataset: BGERerankerDataset,
|
||||
retriever_model: Any
|
||||
):
|
||||
"""Mine hard negatives using the retriever model and update the dataset's negatives."""
|
||||
logger.info("Starting hard negative mining using retriever model...")
|
||||
num_hard_negatives = getattr(self.config, 'num_hard_negatives', 10)
|
||||
retriever_batch_size = getattr(self.config, 'retriever_batch_size', 32)
|
||||
mined_count = 0
|
||||
|
||||
# Collect all unique candidate passages from the dataset
|
||||
all_passages = set()
|
||||
for item in dataset.data:
|
||||
all_passages.update(item.get('pos', []))
|
||||
all_passages.update(item.get('neg', []))
|
||||
all_passages = list(all_passages)
|
||||
|
||||
if not all_passages:
|
||||
logger.warning("No candidate passages found in dataset for hard negative mining.")
|
||||
return
|
||||
|
||||
# Encode all candidate passages
|
||||
try:
|
||||
passage_embeddings = retriever_model.encode(
|
||||
all_passages,
|
||||
batch_size=retriever_batch_size,
|
||||
convert_to_numpy=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error encoding passages for hard negative mining: {e}")
|
||||
return
|
||||
|
||||
# For each query, mine hard negatives
|
||||
for idx in range(len(dataset)):
|
||||
item = dataset.data[idx]
|
||||
query = item['query']
|
||||
positives = set(item.get('pos', []))
|
||||
existing_negs = set(item.get('neg', []))
|
||||
|
||||
try:
|
||||
query_emb = retriever_model.encode([query], batch_size=1, convert_to_numpy=True)[0]
|
||||
|
||||
# Compute similarity
|
||||
scores = np.dot(passage_embeddings, query_emb)
|
||||
|
||||
# Rank passages by similarity
|
||||
ranked_indices = np.argsort(scores)[::-1]
|
||||
|
||||
new_negs = []
|
||||
for i in ranked_indices:
|
||||
passage = all_passages[int(i)] # Convert to int for indexing
|
||||
if passage not in positives and passage not in existing_negs:
|
||||
new_negs.append(passage)
|
||||
if len(new_negs) >= num_hard_negatives:
|
||||
break
|
||||
|
||||
if new_negs:
|
||||
item['neg'] = list(existing_negs) + new_negs
|
||||
mined_count += 1
|
||||
logger.debug(f"Query idx {idx}: Added {len(new_negs)} hard negatives.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error mining hard negatives for query idx {idx}: {e}")
|
||||
|
||||
logger.info(f"Hard negative mining complete. Updated negatives for {mined_count} queries.")
|
||||
|
||||
def save_model(self, save_path: str):
|
||||
"""Save the trained model"""
|
||||
try:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
# Save model
|
||||
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
|
||||
if hasattr(model_to_save, 'save_pretrained'):
|
||||
model_to_save.save_pretrained(save_path) # type: ignore
|
||||
else:
|
||||
# Fallback to torch.save
|
||||
torch.save(model_to_save.state_dict(), os.path.join(save_path, 'pytorch_model.bin')) # type: ignore
|
||||
|
||||
# Save tokenizer if available
|
||||
if self.tokenizer and hasattr(self.tokenizer, 'save_pretrained'):
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
# Save config
|
||||
if hasattr(self.config, 'save'):
|
||||
self.config.save(os.path.join(save_path, 'training_config.json'))
|
||||
|
||||
logger.info(f"Model saved to {save_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model: {e}")
|
||||
raise
|
||||
|
||||
    def export_for_inference(self, output_path: str):
        """Export model for efficient inference

        Writes a single ``torch.save`` bundle (state dict + config + model
        type tag) to ``output_path`` and, best-effort, a HuggingFace-style
        directory next to it. Only the main process performs the export,
        so distributed ranks do not race on the same file.

        Args:
            output_path: Destination file path (conventionally ``*.pt``).
        """
        if is_main_process():
            logger.info(f"Exporting model for inference to {output_path}")

            # Switch to eval mode so exported module state reflects
            # inference behaviour (dropout disabled, etc.).
            self.model.eval()

            # Save a self-contained bundle: weights plus config so the
            # model can be reconstructed without this training script.
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'config': self.config.to_dict(),
                'model_type': 'bge_reranker'
            }, output_path)

            # Also save in HuggingFace format if possible (best effort;
            # a failure here does not invalidate the .pt export above).
            try:
                hf_path = output_path.replace('.pt', '_hf')
                self.save_model(hf_path)
            except Exception as e:
                logger.warning(f"Failed to save in HuggingFace format: {e}")

            logger.info(f"Model exported successfully")
|
||||
534
training/trainer.py
Normal file
534
training/trainer.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from typing import Dict, Optional, List, Tuple, Any, Union, cast
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from abc import ABC, abstractmethod
|
||||
import time
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
|
||||
from torch.nn.utils.clip_grad import clip_grad_norm_
|
||||
from collections import defaultdict
|
||||
|
||||
from utils.checkpoint import CheckpointManager
|
||||
from utils.logging import MetricsLogger, TensorBoardLogger, ProgressLogger
|
||||
from utils.distributed import (
|
||||
setup_distributed, cleanup_distributed, is_main_process,
|
||||
reduce_dict, synchronize, get_rank, get_world_size
|
||||
)
|
||||
from config.base_config import BaseConfig
|
||||
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTrainer(ABC):
|
||||
"""
|
||||
Abstract base trainer class for BGE models
|
||||
|
||||
Notes:
|
||||
- All parameters (config, model, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
|
||||
- This class does not load config directly; pass config values from your main script for consistency.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        config: BaseConfig,
        model: nn.Module,
        optimizer: Optimizer,
        scheduler: Optional[_LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        """Initialize trainer state: device placement, DDP, AMP, logging, checkpointing.

        Args:
            config: Training configuration (device, precision, paths, limits).
            model: Model to train; moved to the target device here.
            optimizer: Pre-built optimizer over the model's parameters.
            scheduler: Optional LR scheduler, stepped once per optimizer step.
            device: Explicit target device; falls back to ``config.device``.
        """
        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

        # Setup device: an explicit argument wins over the config value.
        if device is None:
            self.device = torch.device(config.device)
        else:
            self.device = device

        # Move model to device
        self.model.to(self.device)

        # Setup distributed training if needed (local_rank == -1 means single process)
        self.is_distributed = config.local_rank != -1
        if self.is_distributed:
            setup_distributed()
            self.model = self._setup_distributed_model()

        # Setup mixed precision training
        self.use_amp = config.fp16 or config.bf16
        if self.use_amp and config.bf16:
            # Use bfloat16 for better numerical stability on RTX5090/A100
            self.amp_dtype = torch.bfloat16
        else:
            self.amp_dtype = torch.float16
        # GradScaler is only needed for fp16; bf16 has enough dynamic range.
        self.scaler = GradScaler() if self.use_amp and not config.bf16 else None

        # Memory optimization settings
        self.empty_cache_steps = getattr(config, 'empty_cache_steps', 100)
        self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)

        # Setup logging (metric/TensorBoard writers only on the main process)
        self.metrics_logger = MetricsLogger(config.output_dir) if is_main_process() else None
        self.tb_logger = TensorBoardLogger(
            os.path.join(config.logging_dir, 'tensorboard'),
            enabled=is_main_process()
        )
        self.progress_logger = None  # Will be initialized in train()

        # Setup checkpointing
        self.checkpoint_manager = CheckpointManager(
            output_dir=config.output_dir,
            save_total_limit=config.save_total_limit
        ) if is_main_process() else None

        # Training state
        self.global_step = 0
        self.current_epoch = 0
        # Seed "best" so the first evaluated metric always improves on it.
        self.best_metric = float('-inf') if config.greater_is_better else float('inf')
        self.training_history = []
|
||||
|
||||
def _setup_distributed_model(self) -> nn.Module:
|
||||
"""Setup model for distributed training"""
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
model = DDP(
|
||||
self.model,
|
||||
device_ids=[self.config.local_rank],
|
||||
output_device=self.config.local_rank,
|
||||
find_unused_parameters=True
|
||||
)
|
||||
return model
|
||||
|
||||
    @abstractmethod
    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for a batch - to be implemented by subclasses

        Args:
            batch: Device-resident batch as produced by ``_prepare_batch``.

        Returns:
            Scalar loss tensor with autograd history attached.
        """
        pass
|
||||
|
||||
    @abstractmethod
    def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Evaluate a single batch - to be implemented by subclasses

        Args:
            batch: Device-resident batch as produced by ``_prepare_batch``.

        Returns:
            Mapping of metric name to value for this batch; ``evaluate``
            averages these over all batches.
        """
        pass
|
||||
|
||||
    def train(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        num_epochs: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Main training loop with robust error handling, config-driven parameters, and progress tracking.

        Args:
            train_dataloader: Training batches.
            eval_dataloader: Optional evaluation batches; when given,
                evaluation runs on an epoch cadence (see note below).
            num_epochs: Overrides ``config.num_train_epochs`` when set.

        Returns:
            The per-epoch training history (NOTE(review): the annotation
            says Dict but the method returns ``self.training_history``, a
            list — the ``type: ignore`` below acknowledges this; confirm
            which type callers expect).

        NOTE(review): ``eval_steps`` and ``save_steps`` are interpreted here
        as *epoch* intervals despite their step-oriented names — confirm
        against the config's intended semantics.
        """
        try:
            num_epochs = num_epochs or self.config.num_train_epochs
            total_steps = len(train_dataloader) * num_epochs

            self.progress_logger = ProgressLogger(
                total_steps=total_steps,
                log_interval=self.config.logging_steps
            )

            if is_main_process():
                logger.info("***** Running training *****")
                logger.info(f" Num examples = {len(train_dataloader.dataset)}")  # type: ignore[attr-defined]
                logger.info(f" Num epochs = {num_epochs}")
                logger.info(f" Batch size = {self.config.per_device_train_batch_size}")
                logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}")
                logger.info(f" Total optimization steps = {total_steps}")
                logger.info(f" Device = {self.device}")
                if self.is_distributed:
                    logger.info(f" Distributed training with {get_world_size()} GPUs")

            # Start from self.current_epoch so a resumed run continues
            # where the loaded checkpoint left off.
            for epoch in range(self.current_epoch, num_epochs):
                self.current_epoch = epoch
                epoch_start_time = time.time()

                train_metrics = self.train_epoch(train_dataloader, epoch)

                # Add metrics to training history
                epoch_metrics = {
                    'epoch': epoch + 1,
                    'train_metrics': train_metrics,
                    'global_step': self.global_step,
                    'learning_rate': self.optimizer.param_groups[0]['lr'] if self.optimizer else 0,
                }

                if eval_dataloader is not None and (epoch + 1) % self.config.eval_steps == 0:
                    eval_metrics = self.evaluate(eval_dataloader)
                    epoch_metrics['eval_metrics'] = eval_metrics
                    current_metric = eval_metrics.get(self.config.metric_for_best_model, 0)
                    is_best = self._is_better_metric(current_metric)
                    if is_best:
                        self.best_metric = current_metric
                    if is_main_process() and self.checkpoint_manager:
                        self._save_checkpoint(is_best=is_best, metrics=eval_metrics)
                    if is_main_process():
                        logger.info(f"Epoch {epoch + 1} evaluation: {eval_metrics}")

                # Add to training history
                self.training_history.append(epoch_metrics)

                if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
                    self._save_checkpoint(metrics=train_metrics)

                epoch_time = time.time() - epoch_start_time
                if is_main_process():
                    logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds")
                    logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")

            if is_main_process():
                # NOTE(review): train_metrics is the last epoch's result and
                # is unbound if the loop body never ran — confirm callers
                # never invoke train() with zero remaining epochs.
                self._save_checkpoint(metrics=train_metrics)
                self._save_training_summary()
                # Close TensorBoard logger
                self.tb_logger.close()

            if self.is_distributed:
                cleanup_distributed()

            return self.training_history  # type: ignore[return-value]
        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise
|
||||
|
||||
def train_epoch(
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
epoch: int
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Train for one epoch with comprehensive error handling and recovery
|
||||
|
||||
Args:
|
||||
dataloader: Training data loader
|
||||
epoch: Current epoch number
|
||||
|
||||
Returns:
|
||||
Dictionary of training metrics for the epoch
|
||||
|
||||
Raises:
|
||||
RuntimeError: If training fails and cannot recover
|
||||
"""
|
||||
self.model.train()
|
||||
epoch_loss = 0.0
|
||||
epoch_metrics = defaultdict(float)
|
||||
num_batches = len(dataloader)
|
||||
|
||||
# Track failed batches for recovery
|
||||
failed_batches = []
|
||||
max_failures = 3
|
||||
|
||||
# Training loop with error recovery
|
||||
progress_bar = tqdm(
|
||||
dataloader,
|
||||
desc=f"Epoch {epoch + 1}",
|
||||
disable=not is_main_process()
|
||||
)
|
||||
|
||||
for step, batch in enumerate(progress_bar):
|
||||
try:
|
||||
# Move batch to device with error handling
|
||||
try:
|
||||
batch = self._prepare_batch(batch)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to prepare batch {step}: {e}")
|
||||
failed_batches.append((step, 'prepare', str(e)))
|
||||
if len(failed_batches) > max_failures:
|
||||
raise RuntimeError(f"Too many batch preparation failures: {len(failed_batches)}")
|
||||
continue
|
||||
|
||||
# Compute loss with error recovery
|
||||
loss: torch.Tensor
|
||||
try:
|
||||
if self.use_amp:
|
||||
with autocast(enabled=True, dtype=self.amp_dtype):
|
||||
loss = self.compute_loss(batch)
|
||||
loss = cast(torch.Tensor, loss)
|
||||
else:
|
||||
loss = self.compute_loss(batch)
|
||||
loss = cast(torch.Tensor, loss)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to compute loss for batch {step}: {e}")
|
||||
failed_batches.append((step, 'forward', str(e)))
|
||||
if len(failed_batches) > max_failures:
|
||||
raise RuntimeError(f"Too many forward pass failures: {len(failed_batches)}")
|
||||
continue
|
||||
|
||||
# Scale loss for gradient accumulation
|
||||
loss = loss / self.config.gradient_accumulation_steps
|
||||
|
||||
# Backward pass with error handling
|
||||
try:
|
||||
if self.scaler:
|
||||
self.scaler.scale(loss).backward() # type: ignore[attr-defined]
|
||||
else:
|
||||
loss.backward() # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
logger.warning(f"Backward pass failed for batch {step}: {e}")
|
||||
# Clear gradients and continue
|
||||
self.optimizer.zero_grad()
|
||||
failed_batches.append((step, 'backward', str(e)))
|
||||
if len(failed_batches) > max_failures:
|
||||
raise RuntimeError(f"Too many backward pass failures: {len(failed_batches)}")
|
||||
continue
|
||||
|
||||
# Update weights with error handling
|
||||
if (step + 1) % self.config.gradient_accumulation_steps == 0:
|
||||
try:
|
||||
# Gradient clipping
|
||||
if self.scaler:
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
|
||||
clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.config.max_grad_norm
|
||||
)
|
||||
|
||||
# Optimizer step
|
||||
if self.scaler:
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
|
||||
# Scheduler step
|
||||
if self.scheduler:
|
||||
self.scheduler.step()
|
||||
|
||||
# Update global step
|
||||
self.global_step += 1
|
||||
|
||||
# Clear cache periodically for memory efficiency
|
||||
if self.empty_cache_steps > 0 and self.global_step % self.empty_cache_steps == 0:
|
||||
self._clear_device_cache()
|
||||
|
||||
# Zero gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Optimizer update failed at step {step}: {e}")
|
||||
# Reset optimizer state
|
||||
self.optimizer.zero_grad()
|
||||
if self.scaler:
|
||||
self.scaler.update()
|
||||
failed_batches.append((step, 'optimizer', str(e)))
|
||||
continue
|
||||
|
||||
# Logging with error handling
|
||||
if self.global_step % self.config.logging_steps == 0:
|
||||
try:
|
||||
current_loss = loss.item() * self.config.gradient_accumulation_steps
|
||||
epoch_loss += current_loss
|
||||
|
||||
# Update progress bar
|
||||
progress_bar.set_postfix({
|
||||
'loss': f'{current_loss:.4f}',
|
||||
'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' if self.scheduler else 'N/A',
|
||||
'failures': len(failed_batches)
|
||||
})
|
||||
|
||||
# Log to TensorBoard
|
||||
if self.tb_logger:
|
||||
self.tb_logger.log_scalar('train/loss', current_loss, self.global_step)
|
||||
self.tb_logger.log_scalar('train/learning_rate',
|
||||
self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr'],
|
||||
self.global_step)
|
||||
except Exception as e:
|
||||
logger.debug(f"Logging failed: {e}")
|
||||
|
||||
# Checkpoint saving with error handling
|
||||
if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
|
||||
try:
|
||||
if self.checkpoint_manager:
|
||||
state_dict = {
|
||||
'model_state_dict': self.model.state_dict(),
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
|
||||
'global_step': self.global_step,
|
||||
'current_epoch': epoch,
|
||||
'config': vars(self.config)
|
||||
}
|
||||
self.checkpoint_manager.save(state_dict, self.global_step)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save checkpoint at step {self.global_step}: {e}")
|
||||
# Continue training even if checkpoint fails
|
||||
|
||||
except Exception as e:
|
||||
# Catch-all for any unexpected errors
|
||||
logger.error(f"Unexpected error in training step {step}: {e}")
|
||||
failed_batches.append((step, 'unknown', str(e)))
|
||||
if len(failed_batches) > max_failures * 2:
|
||||
raise RuntimeError(f"Training failed with too many errors: {len(failed_batches)}")
|
||||
continue
|
||||
|
||||
# Log epoch summary
|
||||
if failed_batches:
|
||||
logger.warning(f"Epoch {epoch + 1} completed with {len(failed_batches)} failed batches")
|
||||
for step, stage, error in failed_batches[:5]: # Show first 5 failures
|
||||
logger.debug(f" Batch {step} failed at {stage}: {error}")
|
||||
|
||||
# Compute epoch metrics
|
||||
epoch_metrics['loss'] = epoch_loss / max(num_batches - len(failed_batches), 1)
|
||||
epoch_metrics['failed_batches'] = len(failed_batches)
|
||||
epoch_metrics['success_rate'] = (num_batches - len(failed_batches)) / num_batches if num_batches > 0 else 0.0
|
||||
|
||||
return epoch_metrics
|
||||
|
||||
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
|
||||
"""Evaluate the model"""
|
||||
self.model.eval()
|
||||
|
||||
total_loss = 0
|
||||
num_batches = 0
|
||||
all_metrics = {}
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Evaluating", disable=not is_main_process()):
|
||||
batch = self._prepare_batch(batch)
|
||||
|
||||
# Compute metrics
|
||||
metrics = self.evaluate_step(batch)
|
||||
|
||||
# Accumulate metrics
|
||||
for key, value in metrics.items():
|
||||
if key not in all_metrics:
|
||||
all_metrics[key] = 0
|
||||
all_metrics[key] += value
|
||||
|
||||
num_batches += 1
|
||||
|
||||
# Average metrics
|
||||
for key in all_metrics:
|
||||
all_metrics[key] /= num_batches
|
||||
|
||||
# Gather metrics across devices
|
||||
if self.is_distributed:
|
||||
all_metrics = reduce_dict(all_metrics)
|
||||
|
||||
return all_metrics
|
||||
|
||||
def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Move tensor values in batch dict to the target device, leave non-tensors untouched."""
|
||||
moved = {}
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
moved[k] = v.to(self.device, non_blocking=True)
|
||||
else:
|
||||
moved[k] = v # keep lists / strings / ints as-is
|
||||
return moved
|
||||
|
||||
def _clear_device_cache(self):
|
||||
"""Clear device memory cache with support for CUDA and Ascend NPU"""
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# Try Ascend support if available
|
||||
try:
|
||||
from utils.ascend_support import clear_ascend_cache
|
||||
clear_ascend_cache()
|
||||
except ImportError:
|
||||
# Fallback to manual NPU check
|
||||
try:
|
||||
npu = getattr(torch, 'npu', None)
|
||||
if npu is not None and hasattr(npu, 'is_available') and npu.is_available():
|
||||
npu.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to clear device cache: {e}")
|
||||
|
||||
def _is_better_metric(self, current: float) -> bool:
|
||||
"""Check if current metric is better than best"""
|
||||
if self.config.greater_is_better:
|
||||
return current > self.best_metric
|
||||
else:
|
||||
return current < self.best_metric
|
||||
|
||||
def _save_checkpoint(self, is_best: bool = False, metrics: Optional[Dict[str, float]] = None):
|
||||
"""Save model checkpoint"""
|
||||
if not is_main_process() or not self.checkpoint_manager:
|
||||
return
|
||||
|
||||
checkpoint = {
|
||||
'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(), # type: ignore[attr-defined]
|
||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
|
||||
'config': self.config.to_dict(),
|
||||
'global_step': self.global_step,
|
||||
'current_epoch': self.current_epoch,
|
||||
'best_metric': self.best_metric
|
||||
}
|
||||
|
||||
if self.scaler:
|
||||
checkpoint['scaler_state_dict'] = self.scaler.state_dict()
|
||||
|
||||
self.checkpoint_manager.save(
|
||||
checkpoint,
|
||||
step=self.global_step,
|
||||
is_best=is_best,
|
||||
metrics=metrics
|
||||
)
|
||||
|
||||
def _save_training_summary(self):
|
||||
"""Save training summary"""
|
||||
if not is_main_process():
|
||||
return
|
||||
|
||||
summary = {
|
||||
'config': self.config.to_dict(),
|
||||
'final_metrics': self.training_history[-1] if self.training_history else {},
|
||||
'best_metric': self.best_metric,
|
||||
'total_steps': self.global_step,
|
||||
'total_epochs': self.current_epoch,
|
||||
'training_time': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
summary_path = os.path.join(self.config.output_dir, 'training_summary.json')
|
||||
with open(summary_path, 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
def load_checkpoint(self, checkpoint_path: str):
|
||||
"""Load checkpoint"""
|
||||
logger.info(f"Loading checkpoint from {checkpoint_path}")
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||
|
||||
# Load model state
|
||||
if self.is_distributed:
|
||||
self.model.module.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
|
||||
else:
|
||||
self.model.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
|
||||
|
||||
# Load optimizer state
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
# Load scheduler state
|
||||
if self.scheduler and checkpoint.get('scheduler_state_dict'):
|
||||
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
# Load scaler state
|
||||
if self.scaler and checkpoint.get('scaler_state_dict'):
|
||||
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
||||
|
||||
# Load training state
|
||||
self.global_step = checkpoint.get('global_step', 0)
|
||||
self.current_epoch = checkpoint.get('current_epoch', 0)
|
||||
self.best_metric = checkpoint.get('best_metric', float('-inf'))
|
||||
|
||||
logger.info(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")
|
||||
Reference in New Issue
Block a user