# bge_finetune/training/trainer.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing import Dict, Optional, List, Tuple, Any, Union, cast
import os
import json
import logging
from tqdm import tqdm
from abc import ABC, abstractmethod
import time
from datetime import datetime
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
from torch.nn.utils.clip_grad import clip_grad_norm_
from collections import defaultdict
from utils.checkpoint import CheckpointManager
from utils.logging import MetricsLogger, TensorBoardLogger, ProgressLogger
from utils.distributed import (
setup_distributed, cleanup_distributed, is_main_process,
reduce_dict, synchronize, get_rank, get_world_size
)
from config.base_config import BaseConfig
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
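# Illustrative sketch (not executed here): when transformers is installed, the
# warmup scheduler passed into a trainer can be built like the following, where
# optimizer, warmup_steps and total_steps are placeholders supplied by the caller:
#
#     if get_linear_schedule_with_warmup is not None:
#         scheduler = get_linear_schedule_with_warmup(
#             optimizer,
#             num_warmup_steps=warmup_steps,
#             num_training_steps=total_steps,
#         )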
logger = logging.getLogger(__name__)
class BaseTrainer(ABC):
"""
Abstract base trainer class for BGE models
Notes:
- All parameters (config, model, optimizer, scheduler, device, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
config: BaseConfig,
model: nn.Module,
optimizer: Optimizer,
scheduler: Optional[_LRScheduler] = None,
device: Optional[torch.device] = None
):
self.config = config
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
# Setup device
if device is None:
self.device = torch.device(config.device)
else:
self.device = device
# Move model to device
self.model.to(self.device)
# Setup distributed training if needed
self.is_distributed = config.local_rank != -1
if self.is_distributed:
setup_distributed()
self.model = self._setup_distributed_model()
# Setup mixed precision training
self.use_amp = config.fp16 or config.bf16
if self.use_amp and config.bf16:
# Use bfloat16 for better numerical stability on RTX5090/A100
self.amp_dtype = torch.bfloat16
else:
self.amp_dtype = torch.float16
self.scaler = GradScaler() if self.use_amp and not config.bf16 else None
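        # GradScaler is only needed for float16 loss scaling; bfloat16 has a wide
        # enough dynamic range that no scaler is used when bf16 is enabled.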
# Memory optimization settings
self.empty_cache_steps = getattr(config, 'empty_cache_steps', 100)
self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)
# Setup logging
self.metrics_logger = MetricsLogger(config.output_dir) if is_main_process() else None
self.tb_logger = TensorBoardLogger(
os.path.join(config.logging_dir, 'tensorboard'),
enabled=is_main_process()
)
self.progress_logger = None # Will be initialized in train()
# Setup checkpointing
self.checkpoint_manager = CheckpointManager(
output_dir=config.output_dir,
save_total_limit=config.save_total_limit
) if is_main_process() else None
# Training state
self.global_step = 0
self.current_epoch = 0
self.best_metric = float('-inf') if config.greater_is_better else float('inf')
self.training_history = []
def _setup_distributed_model(self) -> nn.Module:
"""Setup model for distributed training"""
model = DDP(
self.model,
device_ids=[self.config.local_rank],
output_device=self.config.local_rank,
find_unused_parameters=True
)
return model
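    # Distributed runs are expected to be launched externally so that local_rank
    # is populated in the config. A hypothetical launch command (entry point and
    # config path are placeholders) could look like:
    #
    #     torchrun --nproc_per_node=4 train.py --config configs/finetune.yaml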
@abstractmethod
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for a batch - to be implemented by subclasses"""
pass
@abstractmethod
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Evaluate a single batch - to be implemented by subclasses"""
pass
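    # Minimal subclass sketch (illustrative only; EmbeddingTrainer and
    # contrastive_loss are assumptions, not defined in this file):
    #
    #     class EmbeddingTrainer(BaseTrainer):
    #         def compute_loss(self, batch):
    #             q = self.model(batch['query_input_ids'], batch['query_attention_mask'])
    #             d = self.model(batch['doc_input_ids'], batch['doc_attention_mask'])
    #             return contrastive_loss(q, d)
    #
    #         def evaluate_step(self, batch):
    #             return {'loss': self.compute_loss(batch).item()}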
def train(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
num_epochs: Optional[int] = None
    ) -> List[Dict[str, Any]]:
"""
Main training loop with robust error handling, config-driven parameters, and progress tracking.
"""
try:
num_epochs = num_epochs or self.config.num_train_epochs
total_steps = len(train_dataloader) * num_epochs
self.progress_logger = ProgressLogger(
total_steps=total_steps,
log_interval=self.config.logging_steps
)
if is_main_process():
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}") # type: ignore[attr-defined]
logger.info(f" Num epochs = {num_epochs}")
logger.info(f" Batch size = {self.config.per_device_train_batch_size}")
logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {total_steps}")
logger.info(f" Device = {self.device}")
if self.is_distributed:
logger.info(f" Distributed training with {get_world_size()} GPUs")
for epoch in range(self.current_epoch, num_epochs):
self.current_epoch = epoch
epoch_start_time = time.time()
train_metrics = self.train_epoch(train_dataloader, epoch)
# Add metrics to training history
epoch_metrics = {
'epoch': epoch + 1,
'train_metrics': train_metrics,
'global_step': self.global_step,
'learning_rate': self.optimizer.param_groups[0]['lr'] if self.optimizer else 0,
}
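                # Note: eval_steps and save_steps are interpreted here as *epoch*
                # intervals, since evaluation and epoch-level checkpointing happen
                # at epoch boundaries in this loop.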
if eval_dataloader is not None and (epoch + 1) % self.config.eval_steps == 0:
eval_metrics = self.evaluate(eval_dataloader)
epoch_metrics['eval_metrics'] = eval_metrics
current_metric = eval_metrics.get(self.config.metric_for_best_model, 0)
is_best = self._is_better_metric(current_metric)
if is_best:
self.best_metric = current_metric
if is_main_process() and self.checkpoint_manager:
self._save_checkpoint(is_best=is_best, metrics=eval_metrics)
if is_main_process():
logger.info(f"Epoch {epoch + 1} evaluation: {eval_metrics}")
# Add to training history
self.training_history.append(epoch_metrics)
# Save checkpoint at end of epoch (if save_steps allows and we've done training)
if (is_main_process() and
self.config.save_steps > 0 and
self.global_step > 0 and
(epoch + 1) % self.config.save_steps == 0):
self._save_checkpoint(metrics=train_metrics)
epoch_time = time.time() - epoch_start_time
if is_main_process():
logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds")
logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")
if is_main_process():
# Only save final checkpoint if we've actually trained (global_step > 0)
if self.global_step > 0:
self._save_checkpoint(metrics=train_metrics)
self._save_training_summary()
# Close TensorBoard logger
self.tb_logger.close()
if self.is_distributed:
cleanup_distributed()
            return self.training_history
except Exception as e:
logger.error(f"Training failed: {e}")
raise
def train_epoch(
self,
dataloader: DataLoader,
epoch: int
) -> Dict[str, float]:
"""
Train for one epoch with comprehensive error handling and recovery
Args:
dataloader: Training data loader
epoch: Current epoch number
Returns:
Dictionary of training metrics for the epoch
Raises:
RuntimeError: If training fails and cannot recover
"""
self.model.train()
epoch_loss = 0.0
epoch_metrics = defaultdict(float)
num_batches = len(dataloader)
# Track failed batches for recovery
failed_batches = []
max_failures = 3
# Training loop with error recovery
progress_bar = tqdm(
dataloader,
desc=f"Epoch {epoch + 1}",
disable=not is_main_process()
)
for step, batch in enumerate(progress_bar):
try:
# Move batch to device with error handling
try:
batch = self._prepare_batch(batch)
except Exception as e:
logger.warning(f"Failed to prepare batch {step}: {e}")
failed_batches.append((step, 'prepare', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many batch preparation failures: {len(failed_batches)}")
continue
# Compute loss with error recovery
loss: torch.Tensor
try:
if self.use_amp:
with autocast(enabled=True, dtype=self.amp_dtype):
loss = self.compute_loss(batch)
loss = cast(torch.Tensor, loss)
else:
loss = self.compute_loss(batch)
loss = cast(torch.Tensor, loss)
except Exception as e:
logger.warning(f"Failed to compute loss for batch {step}: {e}")
failed_batches.append((step, 'forward', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many forward pass failures: {len(failed_batches)}")
continue
# Scale loss for gradient accumulation
loss = loss / self.config.gradient_accumulation_steps
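                # With accumulation the effective batch size per optimizer step is
                # per_device_train_batch_size * gradient_accumulation_steps * world_size,
                # e.g. 16 * 4 * 2 GPUs = 128 samples (numbers are illustrative).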
# Backward pass with error handling
try:
if self.scaler:
self.scaler.scale(loss).backward() # type: ignore[attr-defined]
else:
loss.backward() # type: ignore[attr-defined]
except Exception as e:
logger.warning(f"Backward pass failed for batch {step}: {e}")
# Clear gradients and continue
self.optimizer.zero_grad()
failed_batches.append((step, 'backward', str(e)))
if len(failed_batches) > max_failures:
raise RuntimeError(f"Too many backward pass failures: {len(failed_batches)}")
continue
# Update weights with error handling
if (step + 1) % self.config.gradient_accumulation_steps == 0:
try:
# Gradient clipping
if self.scaler:
self.scaler.unscale_(self.optimizer)
clip_grad_norm_(
self.model.parameters(),
self.config.max_grad_norm
)
# Optimizer step
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
# Scheduler step
if self.scheduler:
self.scheduler.step()
# Update global step
self.global_step += 1
# Clear cache periodically for memory efficiency
if self.empty_cache_steps > 0 and self.global_step % self.empty_cache_steps == 0:
self._clear_device_cache()
# Zero gradients
self.optimizer.zero_grad()
except Exception as e:
logger.warning(f"Optimizer update failed at step {step}: {e}")
# Reset optimizer state
self.optimizer.zero_grad()
if self.scaler:
self.scaler.update()
failed_batches.append((step, 'optimizer', str(e)))
continue
                # Track the epoch loss for every successful batch, then log at the
                # configured interval
                current_loss = loss.item() * self.config.gradient_accumulation_steps
                epoch_loss += current_loss
                if self.global_step % self.config.logging_steps == 0:
                    try:
                        # Update progress bar
                        progress_bar.set_postfix({
                            'loss': f'{current_loss:.4f}',
                            'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' if self.scheduler else 'N/A',
                            'failures': len(failed_batches)
                        })
                        # Log to TensorBoard
                        if self.tb_logger:
                            self.tb_logger.log_scalar('train/loss', current_loss, self.global_step)
                            self.tb_logger.log_scalar(
                                'train/learning_rate',
                                self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr'],
                                self.global_step
                            )
                    except Exception as e:
                        logger.debug(f"Logging failed: {e}")
# Checkpoint saving with error handling - avoid saving at step 0
if (self.config.save_steps > 0 and
self.global_step > 0 and # Don't save at step 0
self.global_step % self.config.save_steps == 0 and
is_main_process()):
try:
if self.checkpoint_manager:
state_dict = {
                                'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(),  # type: ignore[attr-defined]
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'global_step': self.global_step,
'current_epoch': epoch,
'config': vars(self.config)
}
self.checkpoint_manager.save(state_dict, self.global_step)
except Exception as e:
logger.error(f"Failed to save checkpoint at step {self.global_step}: {e}")
# Continue training even if checkpoint fails
except Exception as e:
# Catch-all for any unexpected errors
logger.error(f"Unexpected error in training step {step}: {e}")
failed_batches.append((step, 'unknown', str(e)))
if len(failed_batches) > max_failures * 2:
raise RuntimeError(f"Training failed with too many errors: {len(failed_batches)}")
continue
# Log epoch summary
if failed_batches:
logger.warning(f"Epoch {epoch + 1} completed with {len(failed_batches)} failed batches")
for step, stage, error in failed_batches[:5]: # Show first 5 failures
logger.debug(f" Batch {step} failed at {stage}: {error}")
# Compute epoch metrics
epoch_metrics['loss'] = epoch_loss / max(num_batches - len(failed_batches), 1)
epoch_metrics['failed_batches'] = len(failed_batches)
epoch_metrics['success_rate'] = (num_batches - len(failed_batches)) / num_batches if num_batches > 0 else 0.0
return epoch_metrics
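    # The returned dict has the shape (values are illustrative):
    #     {'loss': 0.4231, 'failed_batches': 0, 'success_rate': 1.0}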
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
"""Evaluate the model"""
self.model.eval()
num_batches = 0
all_metrics = {}
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating", disable=not is_main_process()):
batch = self._prepare_batch(batch)
# Compute metrics
metrics = self.evaluate_step(batch)
# Accumulate metrics
for key, value in metrics.items():
if key not in all_metrics:
all_metrics[key] = 0
all_metrics[key] += value
num_batches += 1
# Average metrics
for key in all_metrics:
all_metrics[key] /= num_batches
# Gather metrics across devices
if self.is_distributed:
all_metrics = reduce_dict(all_metrics)
return all_metrics
def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
"""Move tensor values in batch dict to the target device, leave non-tensors untouched."""
moved = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
moved[k] = v.to(self.device, non_blocking=True)
else:
moved[k] = v # keep lists / strings / ints as-is
return moved
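    # For example, a batch like
    #     {'input_ids': <tensor>, 'attention_mask': <tensor>, 'texts': ['...']}
    # comes back with both tensors on self.device and the string list untouched.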
def _clear_device_cache(self):
"""Clear device memory cache with support for CUDA and Ascend NPU"""
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# Try Ascend support if available
try:
from utils.ascend_support import clear_ascend_cache
clear_ascend_cache()
except ImportError:
# Fallback to manual NPU check
try:
npu = getattr(torch, 'npu', None)
if npu is not None and hasattr(npu, 'is_available') and npu.is_available():
npu.empty_cache()
except Exception:
pass
except Exception as e:
logger.debug(f"Failed to clear device cache: {e}")
def _is_better_metric(self, current: float) -> bool:
"""Check if current metric is better than best"""
if self.config.greater_is_better:
return current > self.best_metric
else:
return current < self.best_metric
def _save_checkpoint(self, is_best: bool = False, metrics: Optional[Dict[str, float]] = None):
"""Save model checkpoint"""
if not is_main_process() or not self.checkpoint_manager:
return
checkpoint = {
'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(), # type: ignore[attr-defined]
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
'config': self.config.to_dict(),
'global_step': self.global_step,
'current_epoch': self.current_epoch,
'best_metric': self.best_metric
}
if self.scaler:
checkpoint['scaler_state_dict'] = self.scaler.state_dict()
self.checkpoint_manager.save(
checkpoint,
step=self.global_step,
is_best=is_best,
metrics=metrics
)
def _save_training_summary(self):
"""Save training summary"""
if not is_main_process():
return
summary = {
'config': self.config.to_dict(),
'final_metrics': self.training_history[-1] if self.training_history else {},
'best_metric': self.best_metric,
'total_steps': self.global_step,
'total_epochs': self.current_epoch,
            'finished_at': datetime.now().isoformat()
}
summary_path = os.path.join(self.config.output_dir, 'training_summary.json')
with open(summary_path, 'w') as f:
json.dump(summary, f, indent=2)
def load_checkpoint(self, checkpoint_path: str):
"""Load checkpoint"""
logger.info(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Load model state
if self.is_distributed:
self.model.module.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
else:
self.model.load_state_dict(checkpoint['model_state_dict']) # type: ignore[attr-defined]
# Load optimizer state
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Load scheduler state
if self.scheduler and checkpoint.get('scheduler_state_dict'):
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# Load scaler state
if self.scaler and checkpoint.get('scaler_state_dict'):
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
# Load training state
self.global_step = checkpoint.get('global_step', 0)
self.current_epoch = checkpoint.get('current_epoch', 0)
        self.best_metric = checkpoint.get(
            'best_metric',
            float('-inf') if self.config.greater_is_better else float('inf')
        )
logger.info(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")