import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing import Dict, Optional, List, Tuple, Any, Union, cast
import os
import json
import logging
from tqdm import tqdm
from abc import ABC, abstractmethod
import time
from datetime import datetime
import numpy as np
from torch.cuda.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
from torch.nn.parallel import DataParallel, DistributedDataParallel as DDP
from torch.nn.utils.clip_grad import clip_grad_norm_
from collections import defaultdict

from utils.checkpoint import CheckpointManager
from utils.logging import MetricsLogger, TensorBoardLogger, ProgressLogger
from utils.distributed import (
    setup_distributed, cleanup_distributed, is_main_process,
    reduce_dict, synchronize, get_rank, get_world_size
)
from config.base_config import BaseConfig

try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None

logger = logging.getLogger(__name__)


class BaseTrainer(ABC):
    """
    Abstract base trainer class for BGE models.

    Notes:
        - All parameters (config, model, optimizer, scheduler, device, etc.)
          should be passed from config-driven training scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(
        self,
        config: BaseConfig,
        model: nn.Module,
        optimizer: Optimizer,
        scheduler: Optional[_LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

        # Setup device
        if device is None:
            self.device = torch.device(config.device)
        else:
            self.device = device

        # Move model to device
        self.model.to(self.device)

        # Setup distributed training if needed
        self.is_distributed = config.local_rank != -1
        if self.is_distributed:
            setup_distributed()
            self.model = self._setup_distributed_model()

        # Setup mixed precision training
        self.use_amp = config.fp16 or config.bf16
        if self.use_amp and config.bf16:
            # Use bfloat16 for better numerical stability on RTX5090/A100
            self.amp_dtype = torch.bfloat16
        else:
            self.amp_dtype = torch.float16
        # GradScaler is only needed for float16; bfloat16 does not require loss scaling
        self.scaler = GradScaler() if self.use_amp and not config.bf16 else None

        # Memory optimization settings
        self.empty_cache_steps = getattr(config, 'empty_cache_steps', 100)
        self.gradient_checkpointing = getattr(config, 'gradient_checkpointing', False)

        # Setup logging
        self.metrics_logger = MetricsLogger(config.output_dir) if is_main_process() else None
        self.tb_logger = TensorBoardLogger(
            os.path.join(config.logging_dir, 'tensorboard'),
            enabled=is_main_process()
        )
        self.progress_logger = None  # Will be initialized in train()

        # Setup checkpointing
        self.checkpoint_manager = CheckpointManager(
            output_dir=config.output_dir,
            save_total_limit=config.save_total_limit
        ) if is_main_process() else None

        # Training state
        self.global_step = 0
        self.current_epoch = 0
        self.best_metric = float('-inf') if config.greater_is_better else float('inf')
        self.training_history: List[Dict[str, Any]] = []

    def _setup_distributed_model(self) -> nn.Module:
        """Setup model for distributed training"""
        model = DDP(
            self.model,
            device_ids=[self.config.local_rank],
            output_device=self.config.local_rank,
            find_unused_parameters=True
        )
        return model

    @abstractmethod
    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for a batch - to be implemented by subclasses"""
        pass

    @abstractmethod
    def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Evaluate a single batch - to be implemented by subclasses"""
        pass

    def train(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        num_epochs: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Main training loop with robust error handling, config-driven parameters,
        and progress tracking.
        """
        try:
            num_epochs = num_epochs or self.config.num_train_epochs
            total_steps = len(train_dataloader) * num_epochs

            self.progress_logger = ProgressLogger(
                total_steps=total_steps,
                log_interval=self.config.logging_steps
            )

            if is_main_process():
                logger.info("***** Running training *****")
                logger.info(f" Num examples = {len(train_dataloader.dataset)}")  # type: ignore[attr-defined]
                logger.info(f" Num epochs = {num_epochs}")
                logger.info(f" Batch size = {self.config.per_device_train_batch_size}")
                logger.info(f" Gradient accumulation steps = {self.config.gradient_accumulation_steps}")
                logger.info(f" Total optimization steps = {total_steps}")
                logger.info(f" Device = {self.device}")
                if self.is_distributed:
                    logger.info(f" Distributed training with {get_world_size()} GPUs")

            # Ensure train_metrics is defined even if the epoch loop does not run
            # (e.g. when resuming past num_epochs)
            train_metrics: Dict[str, float] = {}

            for epoch in range(self.current_epoch, num_epochs):
                self.current_epoch = epoch
                epoch_start_time = time.time()

                train_metrics = self.train_epoch(train_dataloader, epoch)

                # Add metrics to training history
                epoch_metrics = {
                    'epoch': epoch + 1,
                    'train_metrics': train_metrics,
                    'global_step': self.global_step,
                    'learning_rate': self.optimizer.param_groups[0]['lr'] if self.optimizer else 0,
                }

                if eval_dataloader is not None and (epoch + 1) % self.config.eval_steps == 0:
                    eval_metrics = self.evaluate(eval_dataloader)
                    epoch_metrics['eval_metrics'] = eval_metrics

                    current_metric = eval_metrics.get(self.config.metric_for_best_model, 0)
                    is_best = self._is_better_metric(current_metric)
                    if is_best:
                        self.best_metric = current_metric

                    if is_main_process() and self.checkpoint_manager:
                        self._save_checkpoint(is_best=is_best, metrics=eval_metrics)

                    if is_main_process():
                        logger.info(f"Epoch {epoch + 1} evaluation: {eval_metrics}")

                # Add to training history
                self.training_history.append(epoch_metrics)

                # Save checkpoint at end of epoch (if save_steps allows and we've done training)
                if (is_main_process() and self.config.save_steps > 0 and
                        self.global_step > 0 and (epoch + 1) % self.config.save_steps == 0):
                    self._save_checkpoint(metrics=train_metrics)

                epoch_time = time.time() - epoch_start_time
                if is_main_process():
                    logger.info(f"Epoch {epoch + 1} completed in {epoch_time:.2f} seconds")
                    logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")

            if is_main_process():
                # Only save final checkpoint if we've actually trained (global_step > 0)
                if self.global_step > 0:
                    self._save_checkpoint(metrics=train_metrics)
                self._save_training_summary()

            # Close TensorBoard logger
            self.tb_logger.close()

            if self.is_distributed:
                cleanup_distributed()

            return self.training_history

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def train_epoch(
        self,
        dataloader: DataLoader,
        epoch: int
    ) -> Dict[str, float]:
        """
        Train for one epoch with comprehensive error handling and recovery

        Args:
            dataloader: Training data loader
            epoch: Current epoch number

        Returns:
            Dictionary of training metrics for the epoch

        Raises:
            RuntimeError: If training fails and cannot recover
        """
        self.model.train()

        epoch_loss = 0.0
        epoch_metrics = defaultdict(float)
        num_batches = len(dataloader)

        # Track failed batches for recovery
        failed_batches = []
        max_failures = 3

        # Training loop with error recovery
        progress_bar = tqdm(
            dataloader,
            desc=f"Epoch {epoch + 1}",
            disable=not is_main_process()
        )

        for step, batch in enumerate(progress_bar):
            try:
                # Move batch to device with error handling
                try:
                    batch = self._prepare_batch(batch)
                except Exception as e:
                    logger.warning(f"Failed to prepare batch {step}: {e}")
                    failed_batches.append((step, 'prepare', str(e)))
                    if len(failed_batches) > max_failures:
                        raise RuntimeError(f"Too many batch preparation failures: {len(failed_batches)}")
                    continue

                # Compute loss with error recovery
                loss: torch.Tensor
                try:
                    if self.use_amp:
                        with autocast(enabled=True, dtype=self.amp_dtype):
                            loss = self.compute_loss(batch)
                            loss = cast(torch.Tensor, loss)
                    else:
                        loss = self.compute_loss(batch)
                        loss = cast(torch.Tensor, loss)
                except Exception as e:
                    logger.warning(f"Failed to compute loss for batch {step}: {e}")
                    failed_batches.append((step, 'forward', str(e)))
                    if len(failed_batches) > max_failures:
                        raise RuntimeError(f"Too many forward pass failures: {len(failed_batches)}")
                    continue

                # Scale loss for gradient accumulation
                loss = loss / self.config.gradient_accumulation_steps

                # Backward pass with error handling
                try:
                    if self.scaler:
                        self.scaler.scale(loss).backward()  # type: ignore[attr-defined]
                    else:
                        loss.backward()  # type: ignore[attr-defined]
                except Exception as e:
                    logger.warning(f"Backward pass failed for batch {step}: {e}")
                    # Clear gradients and continue
                    self.optimizer.zero_grad()
                    failed_batches.append((step, 'backward', str(e)))
                    if len(failed_batches) > max_failures:
                        raise RuntimeError(f"Too many backward pass failures: {len(failed_batches)}")
                    continue

                # Accumulate the unscaled loss every successful batch so the epoch
                # average is correct (not only on logging steps)
                current_loss = loss.item() * self.config.gradient_accumulation_steps
                epoch_loss += current_loss

                # Update weights with error handling
                if (step + 1) % self.config.gradient_accumulation_steps == 0:
                    try:
                        # Gradient clipping (unscale first when using AMP)
                        if self.scaler:
                            self.scaler.unscale_(self.optimizer)
                        clip_grad_norm_(
                            self.model.parameters(),
                            self.config.max_grad_norm
                        )

                        # Optimizer step
                        if self.scaler:
                            self.scaler.step(self.optimizer)
                            self.scaler.update()
                        else:
                            self.optimizer.step()

                        # Scheduler step
                        if self.scheduler:
                            self.scheduler.step()

                        # Update global step
                        self.global_step += 1

                        # Clear cache periodically for memory efficiency
                        if self.empty_cache_steps > 0 and self.global_step % self.empty_cache_steps == 0:
                            self._clear_device_cache()

                        # Zero gradients
                        self.optimizer.zero_grad()

                    except Exception as e:
                        logger.warning(f"Optimizer update failed at step {step}: {e}")
                        # Reset optimizer state
                        self.optimizer.zero_grad()
                        if self.scaler:
                            self.scaler.update()
                        failed_batches.append((step, 'optimizer', str(e)))
                        continue

                # Logging with error handling
                if self.global_step % self.config.logging_steps == 0:
                    try:
                        # Update progress bar
                        progress_bar.set_postfix({
                            'loss': f'{current_loss:.4f}',
                            'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' if self.scheduler else 'N/A',
                            'failures': len(failed_batches)
                        })

                        # Log to TensorBoard
                        if self.tb_logger:
                            self.tb_logger.log_scalar('train/loss', current_loss, self.global_step)
                            self.tb_logger.log_scalar(
                                'train/learning_rate',
                                self.scheduler.get_last_lr()[0] if self.scheduler else self.optimizer.param_groups[0]['lr'],
                                self.global_step
                            )
                    except Exception as e:
                        logger.debug(f"Logging failed: {e}")

                # Checkpoint saving with error handling - avoid saving at step 0
                if (self.config.save_steps > 0 and
                        self.global_step > 0 and  # Don't save at step 0
                        self.global_step % self.config.save_steps == 0 and
                        is_main_process()):
                    try:
                        if self.checkpoint_manager:
                            state_dict = {
                                # Unwrap DDP so periodic checkpoints load consistently via load_checkpoint()
                                'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(),  # type: ignore[attr-defined]
                                'optimizer_state_dict': self.optimizer.state_dict(),
                                'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
                                'global_step': self.global_step,
                                'current_epoch': epoch,
                                'config': vars(self.config)
                            }
                            self.checkpoint_manager.save(state_dict, self.global_step)
                    except Exception as e:
                        logger.error(f"Failed to save checkpoint at step {self.global_step}: {e}")
                        # Continue training even if checkpoint fails

            except Exception as e:
                # Catch-all for any unexpected errors
                logger.error(f"Unexpected error in training step {step}: {e}")
                failed_batches.append((step, 'unknown', str(e)))
                if len(failed_batches) > max_failures * 2:
                    raise RuntimeError(f"Training failed with too many errors: {len(failed_batches)}")
                continue

        # Log epoch summary
        if failed_batches:
            logger.warning(f"Epoch {epoch + 1} completed with {len(failed_batches)} failed batches")
            for step, stage, error in failed_batches[:5]:  # Show first 5 failures
                logger.debug(f"  Batch {step} failed at {stage}: {error}")

        # Compute epoch metrics
        epoch_metrics['loss'] = epoch_loss / max(num_batches - len(failed_batches), 1)
        epoch_metrics['failed_batches'] = len(failed_batches)
        epoch_metrics['success_rate'] = (num_batches - len(failed_batches)) / num_batches if num_batches > 0 else 0.0

        return epoch_metrics

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate the model"""
        self.model.eval()

        total_loss = 0
        num_batches = 0
        all_metrics = {}

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating", disable=not is_main_process()):
                batch = self._prepare_batch(batch)

                # Compute metrics
                metrics = self.evaluate_step(batch)

                # Accumulate metrics
                for key, value in metrics.items():
                    if key not in all_metrics:
                        all_metrics[key] = 0
                    all_metrics[key] += value

                num_batches += 1

        # Average metrics
        for key in all_metrics:
            all_metrics[key] /= num_batches

        # Gather metrics across devices
        if self.is_distributed:
            all_metrics = reduce_dict(all_metrics)

        return all_metrics

    def _prepare_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move tensor values in batch dict to the target device, leave non-tensors untouched."""
        moved = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                moved[k] = v.to(self.device, non_blocking=True)
            else:
                moved[k] = v  # keep lists / strings / ints as-is
        return moved

    def _clear_device_cache(self):
        """Clear device memory cache with support for CUDA and Ascend NPU"""
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            else:
                # Try Ascend support if available
                try:
                    from utils.ascend_support import clear_ascend_cache
                    clear_ascend_cache()
                except ImportError:
                    # Fallback to manual NPU check
                    try:
                        npu = getattr(torch, 'npu', None)
                        if npu is not None and hasattr(npu, 'is_available') and npu.is_available():
                            npu.empty_cache()
                    except Exception:
                        pass
        except Exception as e:
            logger.debug(f"Failed to clear device cache: {e}")

    def _is_better_metric(self, current: float) -> bool:
        """Check if current metric is better than best"""
        if self.config.greater_is_better:
            return current > self.best_metric
        else:
            return current < self.best_metric

    def _save_checkpoint(self, is_best: bool = False, metrics: Optional[Dict[str, float]] = None):
        """Save model checkpoint"""
        if not is_main_process() or not self.checkpoint_manager:
            return

        checkpoint = {
            'model_state_dict': self.model.module.state_dict() if self.is_distributed else self.model.state_dict(),  # type: ignore[attr-defined]
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'config': self.config.to_dict(),
            'global_step': self.global_step,
            'current_epoch': self.current_epoch,
            'best_metric': self.best_metric
        }

        if self.scaler:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        self.checkpoint_manager.save(
            checkpoint,
            step=self.global_step,
            is_best=is_best,
            metrics=metrics
        )

    def _save_training_summary(self):
        """Save training summary"""
        if not is_main_process():
            return

        summary = {
            'config': self.config.to_dict(),
            'final_metrics': self.training_history[-1] if self.training_history else {},
            'best_metric': self.best_metric,
            'total_steps': self.global_step,
            'total_epochs': self.current_epoch,
            'training_time': datetime.now().isoformat()
        }

        summary_path = os.path.join(self.config.output_dir, 'training_summary.json')
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

    def load_checkpoint(self, checkpoint_path: str):
        """Load checkpoint"""
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Load model state
        if self.is_distributed:
            self.model.module.load_state_dict(checkpoint['model_state_dict'])  # type: ignore[attr-defined]
        else:
            self.model.load_state_dict(checkpoint['model_state_dict'])  # type: ignore[attr-defined]

        # Load optimizer state
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler state
        if self.scheduler and checkpoint.get('scheduler_state_dict'):
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Load scaler state
        if self.scaler and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        # Load training state
        self.global_step = checkpoint.get('global_step', 0)
        self.current_epoch = checkpoint.get('current_epoch', 0)
        # Fall back to the correct "worst" value for the configured metric direction
        self.best_metric = checkpoint.get(
            'best_metric',
            float('-inf') if self.config.greater_is_better else float('inf')
        )

        logger.info(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")
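

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This commented-out example shows how a
# concrete trainer might implement the two abstract hooks and run training.
# The class name ContrastiveTrainer, the batch keys, the temperature value,
# and the config/model/optimizer objects below are hypothetical placeholders,
# not part of this module; wire them up from your own config-driven script.
# ---------------------------------------------------------------------------
#
# class ContrastiveTrainer(BaseTrainer):
#     def compute_loss(self, batch):
#         # Assumes the model returns (query_emb, passage_emb) for the batch
#         q, p = self.model(batch['query_input_ids'], batch['passage_input_ids'])
#         scores = q @ p.T / 0.05                      # in-batch negatives, temperature 0.05
#         labels = torch.arange(q.size(0), device=q.device)
#         return nn.functional.cross_entropy(scores, labels)
#
#     def evaluate_step(self, batch):
#         # Reuse the training loss as a simple evaluation metric
#         return {'eval_loss': self.compute_loss(batch).item()}
#
# trainer = ContrastiveTrainer(config=config, model=model,
#                              optimizer=optimizer, scheduler=scheduler)
# history = trainer.train(train_dataloader, eval_dataloader)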