import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing import Dict, Optional, List, Tuple, Any, Union, cast
import os
import json
import logging
from tqdm import tqdm
from abc import ABC, abstractmethod
import time
from datetime import datetime
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
from torch.nn.utils.clip_grad import clip_grad_norm_
from collections import defaultdict

from utils.checkpoint import CheckpointManager
from utils.logging import MetricsLogger, TensorBoardLogger, ProgressLogger
from utils.distributed import (
    setup_distributed, cleanup_distributed, is_main_process,
    reduce_dict, synchronize, get_rank, get_world_size
)
from config.base_config import BaseConfig

try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None

logger = logging.getLogger(__name__)


class BaseTrainer(ABC):
    """
    Abstract base trainer class for BGE models.

    Notes:
        - All parameters (config, model, optimizer, scheduler, device, etc.) should be
          passed in from config-driven training scripts.
        - This class does not load a config itself; pass config values from your main
          script for consistency.
    """

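    # Config attributes read via self.config in this class:
    #   device, local_rank, fp16, bf16, output_dir, logging_dir, save_total_limit,
    #   greater_is_better, metric_for_best_model, num_train_epochs,
    #   per_device_train_batch_size, gradient_accumulation_steps, max_grad_norm,
    #   logging_steps, eval_steps, save_steps
    # plus the optional empty_cache_steps and gradient_checkpointing (read with getattr defaults).
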
    def __init__(
        self,
        config: BaseConfig,
        model: nn.Module,
        optimizer: Optimizer,
        scheduler: Optional[_LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

        # Setup device
        if device is None:
            self.device = torch.device(config.device)
        else:
            self.device = device

        # Move model to device
        self.model.to(self.device)

        # Setup distributed training if needed
        self.is_distributed = config.local_rank != -1
        if self.is_distributed:
            setup_distributed()
            self.model = self._setup_distributed_model()

        # Setup mixed precision training
        self.use_amp = config.fp16 or config.bf16
        if self.use_amp and config.bf16:
            # Use bfloat16 for better numerical stability on RTX5090/A100
            self.amp_dtype = torch.bfloat16
        else:
            self.amp_dtype = torch.float16
        # bf16 keeps fp32's exponent range, so loss scaling is only needed for fp16
        self.scaler = GradScaler() if self.use_amp and not config.bf16 else None

        # Memory optimization settings
        self.empty_cache_steps = getattr(config, 'empty_cache_steps', 100)
        self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)

        # Setup logging
        self.metrics_logger = MetricsLogger(config.output_dir) if is_main_process() else None
        self.tb_logger = TensorBoardLogger(
            os.path.join(config.logging_dir, 'tensorboard'),
            enabled=is_main_process()
        )
        self.progress_logger = None  # Will be initialized in train()

        # Setup checkpointing
        self.checkpoint_manager = CheckpointManager(
            output_dir=config.output_dir,
            save_total_limit=config.save_total_limit
        ) if is_main_process() else None

        # Training state
        self.global_step = 0
        self.current_epoch = 0
        self.best_metric = float('-inf') if config.greater_is_better else float('inf')
        self.training_history = []

    def _setup_distributed_model(self) -> nn.Module:
        """Setup model for distributed training"""
        model = DDP(
            self.model,
            device_ids=[self.config.local_rank],
            output_device=self.config.local_rank,
            find_unused_parameters=True
        )
        return model

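    # Illustrative note: distributed runs assume an external launcher has set up the
    # per-process environment (RANK / LOCAL_RANK / WORLD_SIZE) before
    # setup_distributed() is called, e.g. something like:
    #   torchrun --nproc_per_node=4 your_train_script.py ...
    # The launcher command and script name above are only an example, not part of
    # this repository.
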
    @abstractmethod
    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for a batch - to be implemented by subclasses"""
        pass

    @abstractmethod
    def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Evaluate a single batch - to be implemented by subclasses"""
        pass

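    # Minimal subclass sketch (illustrative only; the class name, model call, and
    # batch keys are hypothetical, not part of this module):
    #
    #   class MyEmbeddingTrainer(BaseTrainer):
    #       def compute_loss(self, batch):
    #           scores = self.model(batch['input_ids'], batch['attention_mask'])
    #           return nn.functional.cross_entropy(scores, batch['labels'])
    #
    #       def evaluate_step(self, batch):
    #           loss = self.compute_loss(batch)
    #           return {'loss': loss.item()}
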
    def train(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        num_epochs: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Main training loop with robust error handling, config-driven parameters, and progress tracking.

        Returns the per-epoch training history.
        """
        try:
            num_epochs = num_epochs or self.config.num_train_epochs
            total_steps = len(train_dataloader) * num_epochs
            total_optimization_steps = total_steps // max(self.config.gradient_accumulation_steps, 1)

            self.progress_logger = ProgressLogger(
                total_steps=total_steps,
                log_interval=self.config.logging_steps
            )

            if is_main_process():
                logger.info("***** Running training *****")
                logger.info(f" Num examples = {len(train_dataloader.dataset)}")  # type: ignore[attr-defined]
                logger.info(f" Num epochs = {num_epochs}")
                logger.info(f" Batch size = {self.config.per_device_train_batch_size}")
                logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}")
                logger.info(f" Total optimization steps = {total_optimization_steps}")
                logger.info(f" Device = {self.device}")
                if self.is_distributed:
                    logger.info(f" Distributed training with {get_world_size()} GPUs")

            for epoch in range(self.current_epoch, num_epochs):
                self.current_epoch = epoch
                epoch_start_time = time.time()

                train_metrics = self.train_epoch(train_dataloader, epoch)

                # Add metrics to training history
                epoch_metrics = {
                    'epoch': epoch + 1,
                    'train_metrics': train_metrics,
                    'global_step': self.global_step,
                    'learning_rate': self.optimizer.param_groups[0]['lr'] if self.optimizer else 0,
                }

                # Note: eval_steps and save_steps are interpreted as epoch intervals here
                if eval_dataloader is not None and (epoch + 1) % self.config.eval_steps == 0:
                    eval_metrics = self.evaluate(eval_dataloader)
                    epoch_metrics['eval_metrics'] = eval_metrics
                    current_metric = eval_metrics.get(self.config.metric_for_best_model, 0)
                    is_best = self._is_better_metric(current_metric)
                    if is_best:
                        self.best_metric = current_metric
                    if is_main_process() and self.checkpoint_manager:
                        self._save_checkpoint(is_best=is_best, metrics=eval_metrics)
                    if is_main_process():
                        logger.info(f"Epoch {epoch + 1} evaluation: {eval_metrics}")

                # Add to training history
                self.training_history.append(epoch_metrics)

                if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
                    self._save_checkpoint(metrics=train_metrics)

                epoch_time = time.time() - epoch_start_time
                if is_main_process():
                    logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds")
                    logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")

            if is_main_process():
                self._save_checkpoint(metrics=train_metrics)
                self._save_training_summary()
                # Close TensorBoard logger
                self.tb_logger.close()

            if self.is_distributed:
                cleanup_distributed()

            return self.training_history
        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def train_epoch(
        self,
        dataloader: DataLoader,
        epoch: int
    ) -> Dict[str, float]:
        """
        Train for one epoch with comprehensive error handling and recovery

        Args:
            dataloader: Training data loader
            epoch: Current epoch number

        Returns:
            Dictionary of training metrics for the epoch

        Raises:
            RuntimeError: If training fails and cannot recover
        """
        self.model.train()
        epoch_loss = 0.0
        epoch_metrics = defaultdict(float)
        num_batches = len(dataloader)

        # Track failed batches for recovery
        failed_batches = []
        max_failures = 3

        # Training loop with error recovery
        progress_bar = tqdm(
            dataloader,
            desc=f"Epoch {epoch + 1}",
            disable=not is_main_process()
        )

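        # Per-batch pipeline: move the batch to the device, run the forward pass
        # (under autocast when AMP is enabled), scale the loss for gradient
        # accumulation, backpropagate, and every `gradient_accumulation_steps`
        # batches clip gradients and take an optimizer/scheduler step. Failures at
        # any stage are recorded in `failed_batches` and the batch is skipped.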
        for step, batch in enumerate(progress_bar):
            try:
                # Move batch to device with error handling
                try:
                    batch = self._prepare_batch(batch)
                except Exception as e:
                    logger.warning(f"Failed to prepare batch {step}: {e}")
                    failed_batches.append((step, 'prepare', str(e)))
                    if len(failed_batches) > max_failures:
                        raise RuntimeError(f"Too many batch preparation failures: {len(failed_batches)}")
                    continue

                # Compute loss with error recovery
                loss: torch.Tensor
                try:
                    if self.use_amp:
                        with autocast(enabled=True, dtype=self.amp_dtype):
                            loss = self.compute_loss(batch)
                        loss = cast(torch.Tensor, loss)
                    else:
                        loss = self.compute_loss(batch)
                        loss = cast(torch.Tensor, loss)
                except Exception as e:
                    logger.warning(f"Failed to compute loss for batch {step}: {e}")
                    failed_batches.append((step, 'forward', str(e)))
                    if len(failed_batches) > max_failures:
                        raise RuntimeError(f"Too many forward pass failures: {len(failed_batches)}")
                    continue

                # Scale loss for gradient accumulation
                loss = loss / self.config.gradient_accumulation_steps

                # Backward pass with error handling
                try:
                    if self.scaler:
                        self.scaler.scale(loss).backward()  # type: ignore[attr-defined]
                    else:
                        loss.backward()  # type: ignore[attr-defined]
                except Exception as e:
                    logger.warning(f"Backward pass failed for batch {step}: {e}")
                    # Clear gradients and continue
                    self.optimizer.zero_grad()
                    failed_batches.append((step, 'backward', str(e)))
                    if len(failed_batches) > max_failures:
                        raise RuntimeError(f"Too many backward pass failures: {len(failed_batches)}")
                    continue

                # Track the unscaled per-batch loss so the epoch average covers every
                # successful batch, not only the logging steps
                current_loss = loss.item() * self.config.gradient_accumulation_steps
                epoch_loss += current_loss

                # Update weights with error handling
                if (step + 1) % self.config.gradient_accumulation_steps == 0:
                    try:
                        # Gradient clipping
                        if self.scaler:
                            self.scaler.unscale_(self.optimizer)

                        clip_grad_norm_(
                            self.model.parameters(),
                            self.config.max_grad_norm
                        )

                        # Optimizer step
                        if self.scaler:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            self.optimizer.step()

                        # Scheduler step
                        if self.scheduler:
                            self.scheduler.step()

                        # Update global step
                        self.global_step += 1

                        # Clear cache periodically for memory efficiency
                        if self.empty_cache_steps > 0 and self.global_step % self.empty_cache_steps == 0:
                            self._clear_device_cache()

                        # Zero gradients
                        self.optimizer.zero_grad()

                    except Exception as e:
                        logger.warning(f"Optimizer update failed at step {step}: {e}")
                        # Reset optimizer state
                        self.optimizer.zero_grad()
                        if self.scaler:
                            self.scaler.update()
                        failed_batches.append((step, 'optimizer', str(e)))
                        continue

                # Logging with error handling
                if self.global_step % self.config.logging_steps == 0:
                    try:
                        # Update progress bar
                        progress_bar.set_postfix({
                            'loss': f'{current_loss:.4f}',
                            'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' if self.scheduler else 'N/A',
                            'failures': len(failed_batches)
                        })

                        # Log to TensorBoard
                        if self.tb_logger:
                            self.tb_logger.log_scalar('train/loss', current_loss, self.global_step)
                            self.tb_logger.log_scalar(
                                'train/learning_rate',
                                self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr'],
                                self.global_step
                            )
                    except Exception as e:
                        logger.debug(f"Logging failed: {e}")

                # Checkpoint saving with error handling
                if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
                    try:
                        if self.checkpoint_manager:
                            state_dict = {
                                'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(),  # type: ignore[attr-defined]
                                'optimizer_state_dict': self.optimizer.state_dict(),
                                'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
                                'global_step': self.global_step,
                                'current_epoch': epoch,
                                'config': vars(self.config)
                            }
                            self.checkpoint_manager.save(state_dict, self.global_step)
                    except Exception as e:
                        logger.error(f"Failed to save checkpoint at step {self.global_step}: {e}")
                        # Continue training even if checkpoint fails

            except Exception as e:
                # Catch-all for any unexpected errors
                logger.error(f"Unexpected error in training step {step}: {e}")
                failed_batches.append((step, 'unknown', str(e)))
                if len(failed_batches) > max_failures * 2:
                    raise RuntimeError(f"Training failed with too many errors: {len(failed_batches)}")
                continue

        # Log epoch summary
        if failed_batches:
            logger.warning(f"Epoch {epoch + 1} completed with {len(failed_batches)} failed batches")
            for step, stage, error in failed_batches[:5]:  # Show first 5 failures
                logger.debug(f" Batch {step} failed at {stage}: {error}")

        # Compute epoch metrics
        epoch_metrics['loss'] = epoch_loss / max(num_batches - len(failed_batches), 1)
        epoch_metrics['failed_batches'] = len(failed_batches)
        epoch_metrics['success_rate'] = (num_batches - len(failed_batches)) / num_batches if num_batches > 0 else 0.0

        return epoch_metrics

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate the model"""
        self.model.eval()

        num_batches = 0
        all_metrics = {}

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating", disable=not is_main_process()):
                batch = self._prepare_batch(batch)

                # Compute metrics
                metrics = self.evaluate_step(batch)

                # Accumulate metrics
                for key, value in metrics.items():
                    if key not in all_metrics:
                        all_metrics[key] = 0
                    all_metrics[key] += value

                num_batches += 1

        # Average metrics (guard against an empty dataloader)
        if num_batches > 0:
            for key in all_metrics:
                all_metrics[key] /= num_batches

        # Gather metrics across devices
        if self.is_distributed:
            all_metrics = reduce_dict(all_metrics)

        return all_metrics

    def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move tensor values in batch dict to the target device, leave non-tensors untouched."""
        moved = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                moved[k] = v.to(self.device, non_blocking=True)
            else:
                moved[k] = v  # keep lists / strings / ints as-is
        return moved

    def _clear_device_cache(self):
        """Clear device memory cache with support for CUDA and Ascend NPU"""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            else:
                # Try Ascend support if available
                try:
                    from utils.ascend_support import clear_ascend_cache
                    clear_ascend_cache()
                except ImportError:
                    # Fallback to manual NPU check
                    try:
                        npu = getattr(torch, 'npu', None)
                        if npu is not None and hasattr(npu, 'is_available') and npu.is_available():
                            npu.empty_cache()
                    except Exception:
                        pass
        except Exception as e:
            logger.debug(f"Failed to clear device cache: {e}")

    def _is_better_metric(self, current: float) -> bool:
        """Check if current metric is better than best"""
        if self.config.greater_is_better:
            return current > self.best_metric
        else:
            return current < self.best_metric

    def _save_checkpoint(self, is_best: bool = False, metrics: Optional[Dict[str, float]] = None):
        """Save model checkpoint"""
        if not is_main_process() or not self.checkpoint_manager:
            return

        checkpoint = {
            'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(),  # type: ignore[attr-defined]
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'config': self.config.to_dict(),
            'global_step': self.global_step,
            'current_epoch': self.current_epoch,
            'best_metric': self.best_metric
        }

        if self.scaler:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        self.checkpoint_manager.save(
            checkpoint,
            step=self.global_step,
            is_best=is_best,
            metrics=metrics
        )

    def _save_training_summary(self):
        """Save training summary"""
        if not is_main_process():
            return

        summary = {
            'config': self.config.to_dict(),
            'final_metrics': self.training_history[-1] if self.training_history else {},
            'best_metric': self.best_metric,
            'total_steps': self.global_step,
            'total_epochs': self.current_epoch + 1,
            'training_time': datetime.now().isoformat()  # timestamp when the summary was written
        }

        summary_path = os.path.join(self.config.output_dir, 'training_summary.json')
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

    def load_checkpoint(self, checkpoint_path: str):
        """Load a checkpoint and restore model, optimizer, scheduler, scaler, and training state"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Load model state
        if self.is_distributed:
            self.model.module.load_state_dict(checkpoint['model_state_dict'])  # type: ignore[attr-defined]
        else:
            self.model.load_state_dict(checkpoint['model_state_dict'])  # type: ignore[attr-defined]

        # Load optimizer state
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler state
        if self.scheduler and checkpoint.get('scheduler_state_dict'):
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Load scaler state
        if self.scaler and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        # Load training state
        self.global_step = checkpoint.get('global_step', 0)
        self.current_epoch = checkpoint.get('current_epoch', 0)
        # Fall back to the current best_metric so the default respects greater_is_better
        self.best_metric = checkpoint.get('best_metric', self.best_metric)

        logger.info(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")
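
# Typical wiring from a config-driven training script (illustrative sketch only;
# `MyEmbeddingTrainer`, `build_model`, `build_dataloaders`, and `load_config` are
# hypothetical helpers, not part of this module):
#
#   config = load_config('configs/train.json')
#   model = build_model(config)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
#   train_loader, eval_loader = build_dataloaders(config)
#
#   trainer = MyEmbeddingTrainer(config, model, optimizer)
#   # trainer.load_checkpoint('outputs/checkpoint-1000.pt')  # optional resume
#   history = trainer.train(train_loader, eval_loader)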