init commit
This commit is contained in:
534
training/trainer.py
Normal file
534
training/trainer.py
Normal file
@@ -0,0 +1,534 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from typing import Dict, Optional, List, Tuple, Any, Union, cast
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
from abc import ABC, abstractmethod
|
||||
import time
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.cuda.amp.grad_scaler import GradScaler
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
|
||||
from torch.nn.utils.clip_grad import clip_grad_norm_
|
||||
from collections import defaultdict
|
||||
|
||||
from utils.checkpoint import CheckpointManager
|
||||
from utils.logging import MetricsLogger, TensorBoardLogger, ProgressLogger
|
||||
from utils.distributed import (
|
||||
setup_distributed, cleanup_distributed, is_main_process,
|
||||
reduce_dict, synchronize, get_rank, get_world_size
|
||||
)
|
||||
from config.base_config import BaseConfig
|
||||
|
||||
try:
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
except ImportError:
|
||||
get_linear_schedule_with_warmup = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTrainer(ABC):
    """
    Abstract base trainer class for BGE models.

    Subclasses implement ``compute_loss`` (training objective for one batch)
    and ``evaluate_step`` (metrics for one evaluation batch); this class
    supplies the epoch loop, gradient accumulation, mixed precision,
    checkpointing and logging around them.

    Notes:
        - All parameters (config, model, optimizer, scheduler, device, etc.)
          should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
        - ``config.save_steps`` is consulted both as an epoch interval (in
          ``train``) and as an optimizer-step interval (in ``train_epoch``).
          NOTE(review): confirm which of the two meanings is intended.
    """

    # Failed batches tolerated per epoch before the epoch is aborted.
    # Unexpected (unclassified) errors are tolerated up to twice this number.
    MAX_BATCH_FAILURES = 3

    class _TooManyFailuresError(RuntimeError):
        """Abort signal for an epoch; must never be swallowed by the per-batch catch-all."""

    def __init__(
        self,
        config: "BaseConfig",
        model: nn.Module,
        optimizer: Optimizer,
        scheduler: Optional[_LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        """
        Store the training components and build the supporting infrastructure.

        Args:
            config: Training configuration object; read for device, precision,
                logging, checkpointing and scheduling settings.
            model: Model to train; moved to the target device here.
            optimizer: Optimizer already bound to the model parameters.
            scheduler: Optional LR scheduler, stepped once per optimizer update.
            device: Target device; falls back to ``config.device`` when None.
        """
        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

        # Setup device
        if device is None:
            self.device = torch.device(config.device)
        else:
            self.device = device

        # Move model to device
        self.model.to(self.device)

        # Setup distributed training if needed (local_rank == -1 means single process)
        self.is_distributed = config.local_rank != -1
        if self.is_distributed:
            setup_distributed()
            self.model = self._setup_distributed_model()

        # Setup mixed precision training
        self.use_amp = config.fp16 or config.bf16
        if self.use_amp and config.bf16:
            # Use bfloat16 for better numerical stability on RTX5090/A100
            self.amp_dtype = torch.bfloat16
        else:
            self.amp_dtype = torch.float16
        # Loss scaling is only created for fp16; bf16 runs without a scaler.
        self.scaler = GradScaler() if self.use_amp and not config.bf16 else None

        # Memory optimization settings
        self.empty_cache_steps = getattr(config, 'empty_cache_steps', 100)
        self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)

        # Setup logging
        self.metrics_logger = MetricsLogger(config.output_dir) if is_main_process() else None
        self.tb_logger = TensorBoardLogger(
            os.path.join(config.logging_dir, 'tensorboard'),
            enabled=is_main_process()
        )
        self.progress_logger = None  # Will be initialized in train()

        # Setup checkpointing (main process only)
        self.checkpoint_manager = CheckpointManager(
            output_dir=config.output_dir,
            save_total_limit=config.save_total_limit
        ) if is_main_process() else None

        # Training state
        self.global_step = 0
        self.current_epoch = 0
        self.best_metric = float('-inf') if config.greater_is_better else float('inf')
        self.training_history = []

    def _setup_distributed_model(self) -> nn.Module:
        """Wrap ``self.model`` in DistributedDataParallel on ``config.local_rank``."""
        return DDP(
            self.model,
            device_ids=[self.config.local_rank],
            output_device=self.config.local_rank,
            # Defaults to the previous hard-coded value; overridable from config.
            find_unused_parameters=getattr(self.config, 'ddp_find_unused_parameters', True)
        )

    def _unwrapped_model(self) -> nn.Module:
        """Return the underlying module, stripping the DDP wrapper when distributed."""
        return self.model.module if self.is_distributed else self.model  # type: ignore[return-value]

    @staticmethod
    def _interval_hit(count: int, interval: Optional[int]) -> bool:
        """Return True when ``interval`` is a positive number that divides ``count``."""
        return bool(interval) and interval > 0 and count % interval == 0

    @abstractmethod
    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for a batch - to be implemented by subclasses"""
        pass

    @abstractmethod
    def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Evaluate a single batch - to be implemented by subclasses"""
        pass

    def train(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        num_epochs: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Main training loop with robust error handling, config-driven parameters, and progress tracking.

        Args:
            train_dataloader: Batches used for optimization.
            eval_dataloader: Optional batches evaluated every ``config.eval_steps`` epochs.
            num_epochs: Number of epochs; falls back to ``config.num_train_epochs``.

        Returns:
            The training history: one dict per epoch run in this call (plus any
            entries already present), with keys ``epoch``, ``train_metrics``,
            ``global_step``, ``learning_rate`` and, when evaluated, ``eval_metrics``.

        Raises:
            Exception: Any error from the loop is logged and re-raised.
        """
        try:
            num_epochs = num_epochs or self.config.num_train_epochs
            # One optimizer update per accumulation window, rounding up because
            # train_epoch flushes a trailing partial window.
            accumulation = max(1, self.config.gradient_accumulation_steps)
            updates_per_epoch = (len(train_dataloader) + accumulation - 1) // accumulation
            total_steps = updates_per_epoch * num_epochs

            self.progress_logger = ProgressLogger(
                total_steps=total_steps,
                log_interval=self.config.logging_steps
            )

            if is_main_process():
                logger.info("***** Running training *****")
                logger.info(f" Num examples = {len(train_dataloader.dataset)}")  # type: ignore[attr-defined]
                logger.info(f" Num epochs = {num_epochs}")
                logger.info(f" Batch size = {self.config.per_device_train_batch_size}")
                logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}")
                logger.info(f" Total optimization steps = {total_steps}")
                logger.info(f" Device = {self.device}")
                if self.is_distributed:
                    logger.info(f" Distributed training with {get_world_size()} GPUs")

            # Bound up front so the final checkpoint works even when the loop
            # body never runs (e.g. resuming at or past num_epochs).
            train_metrics: Dict[str, float] = {}

            for epoch in range(self.current_epoch, num_epochs):
                self.current_epoch = epoch
                epoch_start_time = time.time()

                train_metrics = self.train_epoch(train_dataloader, epoch)

                # Add metrics to training history
                epoch_metrics = {
                    'epoch': epoch + 1,
                    'train_metrics': train_metrics,
                    'global_step': self.global_step,
                    'learning_rate': self.optimizer.param_groups[0]['lr'] if self.optimizer else 0,
                }

                if eval_dataloader is not None and self._interval_hit(epoch + 1, self.config.eval_steps):
                    eval_metrics = self.evaluate(eval_dataloader)
                    epoch_metrics['eval_metrics'] = eval_metrics
                    is_best = self._update_best_metric(eval_metrics)
                    # _save_checkpoint is a no-op off the main process.
                    self._save_checkpoint(is_best=is_best, metrics=eval_metrics)
                    if is_main_process():
                        logger.info(f"Epoch {epoch + 1} evaluation: {eval_metrics}")

                # Add to training history
                self.training_history.append(epoch_metrics)

                if self._interval_hit(epoch + 1, self.config.save_steps):
                    self._save_checkpoint(metrics=train_metrics)

                epoch_time = time.time() - epoch_start_time
                if is_main_process():
                    logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds")
                    logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")

            if is_main_process():
                self._save_checkpoint(metrics=train_metrics)
                self._save_training_summary()

            return self.training_history
        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise
        finally:
            # Release the TensorBoard writer and the process group on both the
            # success and the failure path.
            if is_main_process():
                self.tb_logger.close()
            if self.is_distributed:
                cleanup_distributed()

    def _update_best_metric(self, eval_metrics: Dict[str, float]) -> bool:
        """
        Update ``self.best_metric`` from ``eval_metrics`` and report whether it improved.

        A missing ``config.metric_for_best_model`` key is reported and treated
        as "no improvement" instead of silently comparing against 0.
        """
        metric_name = self.config.metric_for_best_model
        if metric_name not in eval_metrics:
            logger.warning(
                f"Metric '{metric_name}' not found in evaluation metrics; skipping best-model tracking"
            )
            return False

        current_metric = eval_metrics[metric_name]
        if not self._is_better_metric(current_metric):
            return False
        self.best_metric = current_metric
        return True

    def train_epoch(
        self,
        dataloader: DataLoader,
        epoch: int
    ) -> Dict[str, float]:
        """
        Train for one epoch with comprehensive error handling and recovery

        Args:
            dataloader: Training data loader
            epoch: Current epoch number

        Returns:
            Dictionary of training metrics for the epoch: ``loss`` (mean over
            batches whose forward and backward pass succeeded),
            ``failed_batches`` and ``success_rate``.

        Raises:
            RuntimeError: If training fails and cannot recover
        """
        self.model.train()
        # Discard gradients left behind by a previously interrupted window.
        self.optimizer.zero_grad()

        num_batches = len(dataloader)
        failed_batches: List[Tuple[int, str, str]] = []
        loss_total: Optional[torch.Tensor] = None  # kept on device to avoid a sync per batch
        loss_count = 0

        progress_bar = tqdm(
            dataloader,
            desc=f"Epoch {epoch + 1}",
            disable=not is_main_process()
        )

        for step, batch in enumerate(progress_bar):
            try:
                batch_loss = self._forward_backward(batch, step, failed_batches)
                if batch_loss is None:
                    continue
                loss_total = batch_loss if loss_total is None else loss_total + batch_loss
                loss_count += 1

                if not self._should_update(step, num_batches):
                    continue
                if not self._optimizer_update(step, failed_batches):
                    continue

                # Intervals are evaluated only right after global_step changed,
                # so each step is logged / checkpointed at most once.
                if self._interval_hit(self.global_step, self.config.logging_steps):
                    self._log_train_step(progress_bar, batch_loss, len(failed_batches))
                if self._interval_hit(self.global_step, self.config.save_steps):
                    self._save_step_checkpoint()
            except self._TooManyFailuresError:
                # Deliberate abort: must not be downgraded to a "failed batch".
                raise
            except Exception as e:
                # Catch-all for any unexpected errors
                logger.error(f"Unexpected error in training step {step}: {e}")
                failed_batches.append((step, 'unknown', str(e)))
                if len(failed_batches) > self.MAX_BATCH_FAILURES * 2:
                    raise RuntimeError(
                        f"Training failed with too many errors: {len(failed_batches)}"
                    ) from e
                continue

        # Log epoch summary
        if failed_batches:
            logger.warning(f"Epoch {epoch + 1} completed with {len(failed_batches)} failed batches")
            for failed_step, stage, error in failed_batches[:5]:  # Show first 5 failures
                logger.debug(f" Batch {failed_step} failed at {stage}: {error}")

        # Compute epoch metrics
        epoch_loss = loss_total.item() if loss_total is not None else 0.0
        epoch_metrics = defaultdict(float)
        epoch_metrics['loss'] = epoch_loss / max(loss_count, 1)
        epoch_metrics['failed_batches'] = len(failed_batches)
        epoch_metrics['success_rate'] = (
            (num_batches - len(failed_batches)) / num_batches if num_batches > 0 else 0.0
        )
        return epoch_metrics

    def _record_failure(
        self,
        failed_batches: List[Tuple[int, str, str]],
        step: int,
        stage: str,
        error: Exception,
        abort_message: str
    ) -> None:
        """Append a failure record and abort the epoch once MAX_BATCH_FAILURES is exceeded."""
        failed_batches.append((step, stage, str(error)))
        if len(failed_batches) > self.MAX_BATCH_FAILURES:
            raise self._TooManyFailuresError(f"{abort_message}: {len(failed_batches)}") from error

    def _forward_backward(
        self,
        batch: Dict[str, Any],
        step: int,
        failed_batches: List[Tuple[int, str, str]]
    ) -> Optional[torch.Tensor]:
        """
        Run device transfer, forward pass and backward pass for one batch.

        Returns:
            The detached, unscaled loss of the batch, or None when the batch
            failed and was recorded in ``failed_batches``.
        """
        # Move batch to device with error handling
        try:
            batch = self._prepare_batch(batch)
        except Exception as e:
            logger.warning(f"Failed to prepare batch {step}: {e}")
            self._record_failure(failed_batches, step, 'prepare', e, "Too many batch preparation failures")
            return None

        # Compute loss with error recovery
        try:
            if self.use_amp:
                with autocast(enabled=True, dtype=self.amp_dtype):
                    loss = self.compute_loss(batch)
            else:
                loss = self.compute_loss(batch)
        except Exception as e:
            logger.warning(f"Failed to compute loss for batch {step}: {e}")
            self._record_failure(failed_batches, step, 'forward', e, "Too many forward pass failures")
            return None

        # Keep the unscaled value for reporting before dividing for accumulation.
        reported_loss = loss.detach().float()
        loss = loss / max(1, self.config.gradient_accumulation_steps)

        # Backward pass with error handling
        try:
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()
        except Exception as e:
            logger.warning(f"Backward pass failed for batch {step}: {e}")
            # Clear gradients and continue
            self.optimizer.zero_grad()
            self._record_failure(failed_batches, step, 'backward', e, "Too many backward pass failures")
            return None

        return reported_loss

    def _should_update(self, step: int, num_batches: int) -> bool:
        """Return True at the end of an accumulation window or on the epoch's last batch."""
        accumulation = max(1, self.config.gradient_accumulation_steps)
        # The second clause flushes a trailing partial window so its gradients
        # are applied instead of leaking into the next epoch.
        return (step + 1) % accumulation == 0 or step + 1 == num_batches

    def _optimizer_update(self, step: int, failed_batches: List[Tuple[int, str, str]]) -> bool:
        """
        Clip gradients, step optimizer and scheduler, and advance ``global_step``.

        Returns:
            True when the update succeeded, False when it failed and was recorded.
        """
        try:
            # Gradient clipping (gradients must be unscaled first under fp16)
            if self.scaler is not None:
                self.scaler.unscale_(self.optimizer)
            clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

            # Optimizer step
            if self.scaler is not None:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            # Scheduler step
            if self.scheduler:
                self.scheduler.step()

            # Update global step
            self.global_step += 1

            # Clear cache periodically for memory efficiency
            if self.empty_cache_steps > 0 and self.global_step % self.empty_cache_steps == 0:
                self._clear_device_cache()

            # Zero gradients
            self.optimizer.zero_grad()
        except Exception as e:
            logger.warning(f"Optimizer update failed at step {step}: {e}")
            # Reset optimizer state
            self.optimizer.zero_grad()
            if self.scaler is not None:
                # Resets the scaler's per-step bookkeeping; an error raised here
                # propagates to the caller's catch-all.
                self.scaler.update()
            self._record_failure(failed_batches, step, 'optimizer', e, "Too many optimizer update failures")
            return False
        return True

    def _log_train_step(self, progress_bar: Any, batch_loss: torch.Tensor, num_failures: int) -> None:
        """Push loss and learning rate to the progress bar and TensorBoard (best effort)."""
        try:
            current_loss = batch_loss.item()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': f'{current_loss:.4f}',
                'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' if self.scheduler else 'N/A',
                'failures': num_failures
            })

            # Log to TensorBoard
            if self.tb_logger:
                self.tb_logger.log_scalar('train/loss', current_loss, self.global_step)
                self.tb_logger.log_scalar(
                    'train/learning_rate',
                    self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr'],
                    self.global_step
                )
        except Exception as e:
            logger.debug(f"Logging failed: {e}")

    def _save_step_checkpoint(self) -> None:
        """Save a mid-epoch checkpoint in the same format as ``_save_checkpoint``; never raises."""
        try:
            self._save_checkpoint()
        except Exception as e:
            # Continue training even if checkpoint fails
            logger.error(f"Failed to save checkpoint at step {self.global_step}: {e}")

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """
        Evaluate the model.

        Runs ``evaluate_step`` on every batch under ``torch.no_grad()`` and
        returns the per-key mean over batches; when distributed, the result is
        passed through ``reduce_dict``. An empty dataloader yields an empty dict.
        """
        self.model.eval()

        metric_sums: Dict[str, float] = {}
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating", disable=not is_main_process()):
                batch = self._prepare_batch(batch)

                # Accumulate metrics
                for key, value in self.evaluate_step(batch).items():
                    metric_sums[key] = metric_sums.get(key, 0) + value

                num_batches += 1

        # Average metrics (metric_sums is empty when no batch was seen)
        all_metrics = {key: total / num_batches for key, total in metric_sums.items()}

        # Gather metrics across devices
        if self.is_distributed:
            all_metrics = reduce_dict(all_metrics)

        return all_metrics

    def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move tensor values in batch dict to the target device, leave non-tensors untouched."""
        return {
            key: value.to(self.device, non_blocking=True) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def _clear_device_cache(self):
        """Clear device memory cache with support for CUDA and Ascend NPU (best effort, never raises)."""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                return

            # Try Ascend support if available
            try:
                from utils.ascend_support import clear_ascend_cache
                clear_ascend_cache()
            except ImportError:
                # Fallback to manual NPU check
                npu = getattr(torch, 'npu', None)
                if npu is not None and hasattr(npu, 'is_available') and npu.is_available():
                    npu.empty_cache()
        except Exception as e:
            logger.debug(f"Failed to clear device cache: {e}")

    def _is_better_metric(self, current: float) -> bool:
        """Check if current metric is better than best"""
        if self.config.greater_is_better:
            return current > self.best_metric
        return current < self.best_metric

    def _save_checkpoint(self, is_best: bool = False, metrics: Optional[Dict[str, float]] = None):
        """
        Save model checkpoint.

        Does nothing off the main process or when no checkpoint manager exists.
        The model weights are saved without the DDP wrapper so they can be
        loaded by ``load_checkpoint``.
        """
        if not is_main_process() or not self.checkpoint_manager:
            return

        checkpoint = {
            'model_state_dict': self._unwrapped_model().state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'config': self.config.to_dict(),
            'global_step': self.global_step,
            'current_epoch': self.current_epoch,
            'best_metric': self.best_metric
        }

        if self.scaler:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        self.checkpoint_manager.save(
            checkpoint,
            step=self.global_step,
            is_best=is_best,
            metrics=metrics
        )

    def _save_training_summary(self):
        """Write ``training_summary.json`` into ``config.output_dir`` (main process only)."""
        if not is_main_process():
            return

        summary = {
            'config': self.config.to_dict(),
            'final_metrics': self.training_history[-1] if self.training_history else {},
            'best_metric': self.best_metric,
            'total_steps': self.global_step,
            # NOTE(review): current_epoch is the zero-based index of the last
            # epoch run, so this is one less than the number of epochs completed.
            'total_epochs': self.current_epoch,
            # Timestamp of when the summary was written, not a duration.
            'training_time': datetime.now().isoformat()
        }

        os.makedirs(self.config.output_dir, exist_ok=True)
        summary_path = os.path.join(self.config.output_dir, 'training_summary.json')
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

    def load_checkpoint(self, checkpoint_path: str):
        """
        Load checkpoint.

        Restores model, optimizer, scheduler and scaler state (when present)
        plus ``global_step``, ``current_epoch`` and ``best_metric``.

        Args:
            checkpoint_path: Path of a checkpoint file readable by ``torch.load``.
        """
        logger.info(f"Loading checkpoint from {checkpoint_path}")

        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Load model state (into the underlying module when wrapped by DDP)
        self._unwrapped_model().load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler state
        if self.scheduler and checkpoint.get('scheduler_state_dict'):
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Load scaler state
        if self.scaler and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        # Load training state
        self.global_step = checkpoint.get('global_step', 0)
        self.current_epoch = checkpoint.get('current_epoch', 0)
        # Keep the direction-aware initial value when the checkpoint has none;
        # a fixed -inf default would block every improvement when lower is better.
        self.best_metric = checkpoint.get('best_metric', self.best_metric)

        logger.info(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")
|
||||
Reference in New Issue
Block a user