import os
import torch
import json
import shutil
from typing import Dict, Any, Optional
import logging
from datetime import datetime
import glob

logger = logging.getLogger(__name__)


class CheckpointManager:
    """
    Manages model checkpoints with automatic cleanup and best model tracking.
    """

    def __init__(
        self,
        output_dir: str,
        save_total_limit: int = 3,
        best_model_dir: str = "best_model",
        save_optimizer: bool = True,
        save_scheduler: bool = True
    ):
        self.output_dir = output_dir
        self.save_total_limit = save_total_limit
        self.best_model_dir = os.path.join(output_dir, best_model_dir)
        self.save_optimizer = save_optimizer
        self.save_scheduler = save_scheduler

        # Create directories
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(self.best_model_dir, exist_ok=True)

        # Track checkpoints as (step, path) tuples
        self.checkpoints = []
        self._load_existing_checkpoints()

    def _load_existing_checkpoints(self):
        """Load the list of existing checkpoints from the output directory."""
        checkpoint_dirs = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))

        # Sort by step number; skip directories whose suffix is not an integer
        self.checkpoints = []
        for cp_dir in checkpoint_dirs:
            try:
                step = int(cp_dir.split("-")[-1])
                self.checkpoints.append((step, cp_dir))
            except ValueError:
                continue
        self.checkpoints.sort(key=lambda x: x[0])

    def save(
        self,
        state_dict: Dict[str, Any],
        step: int,
        is_best: bool = False,
        metrics: Optional[Dict[str, float]] = None
    ) -> str:
        """
        Save a checkpoint with atomic writes and error handling.

        Args:
            state_dict: State dictionary containing model, optimizer, etc.
            step: Current training step
            is_best: Whether this is the best model so far
            metrics: Optional metrics to save with the checkpoint

        Returns:
            Path to the saved checkpoint
        """
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        temp_dir = os.path.join(self.output_dir, f"checkpoint-{step}-tmp")

        try:
            # Create temporary directory for atomic write
            os.makedirs(temp_dir, exist_ok=True)

            # Save individual components
            checkpoint_files = {}

            # Save model with error handling
            if 'model_state_dict' in state_dict:
                model_path = os.path.join(temp_dir, 'pytorch_model.bin')
                self._safe_torch_save(state_dict['model_state_dict'], model_path)
                checkpoint_files['model'] = 'pytorch_model.bin'

            # Save optimizer
            if self.save_optimizer and 'optimizer_state_dict' in state_dict:
                optimizer_path = os.path.join(temp_dir, 'optimizer.pt')
                self._safe_torch_save(state_dict['optimizer_state_dict'], optimizer_path)
                checkpoint_files['optimizer'] = 'optimizer.pt'

            # Save scheduler
            if self.save_scheduler and 'scheduler_state_dict' in state_dict:
                scheduler_path = os.path.join(temp_dir, 'scheduler.pt')
                self._safe_torch_save(state_dict['scheduler_state_dict'], scheduler_path)
                checkpoint_files['scheduler'] = 'scheduler.pt'

            # Save training state
            training_state = {
                'step': step,
                'checkpoint_files': checkpoint_files,
                'timestamp': datetime.now().isoformat()
            }

            # Add additional state info
            for key in ['global_step', 'current_epoch', 'best_metric']:
                if key in state_dict:
                    training_state[key] = state_dict[key]

            if metrics:
                training_state['metrics'] = metrics

            # Save config if provided
            if 'config' in state_dict:
                config_path = os.path.join(temp_dir, 'config.json')
                self._safe_json_save(state_dict['config'], config_path)

            # Save training state
            state_path = os.path.join(temp_dir, 'training_state.json')
            self._safe_json_save(training_state, state_path)

            # Atomic rename from temp to final directory
            if os.path.exists(checkpoint_dir):
                # Backup existing checkpoint
                backup_dir = f"{checkpoint_dir}.backup"
                if os.path.exists(backup_dir):
                    shutil.rmtree(backup_dir)
                shutil.move(checkpoint_dir, backup_dir)
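            # At this point every file has been written to the temp directory, so the
            # swap below is the only non-atomic window: any pre-existing checkpoint has
            # already been moved aside to "<checkpoint_dir>.backup", and the except
            # handler restores it if the move fails.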
            # Move temp directory to final location
            shutil.move(temp_dir, checkpoint_dir)

            # Remove backup if successful
            if os.path.exists(f"{checkpoint_dir}.backup"):
                shutil.rmtree(f"{checkpoint_dir}.backup")

            # Add to checkpoint list
            self.checkpoints.append((step, checkpoint_dir))

            # Save best model
            if is_best:
                self._save_best_model(checkpoint_dir, step, metrics)

            # Cleanup old checkpoints
            self._cleanup_checkpoints()

            logger.info(f"Saved checkpoint to {checkpoint_dir}")
            return checkpoint_dir

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

            # Cleanup temp directory
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)

            # Restore backup if it exists
            backup_dir = f"{checkpoint_dir}.backup"
            if os.path.exists(backup_dir):
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                shutil.move(backup_dir, checkpoint_dir)

            raise RuntimeError(f"Checkpoint save failed: {e}")

    def _save_best_model(
        self,
        checkpoint_dir: str,
        step: int,
        metrics: Optional[Dict[str, float]] = None
    ):
        """Save the best model."""
        # Copy checkpoint files to the best model directory
        for filename in os.listdir(checkpoint_dir):
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(self.best_model_dir, filename)
            if os.path.isfile(src):
                shutil.copy2(src, dst)

        # Save best model info
        best_info = {
            'step': step,
            'source_checkpoint': checkpoint_dir,
            'timestamp': datetime.now().isoformat()
        }
        if metrics:
            best_info['metrics'] = metrics

        best_info_path = os.path.join(self.best_model_dir, 'best_model_info.json')
        with open(best_info_path, 'w') as f:
            json.dump(best_info, f, indent=2)

        logger.info(f"Saved best model from step {step}")

    def _cleanup_checkpoints(self):
        """Remove old checkpoints to maintain save_total_limit."""
        if self.save_total_limit is not None and len(self.checkpoints) > self.save_total_limit:
            # Remove oldest checkpoints
            checkpoints_to_remove = self.checkpoints[:-self.save_total_limit]

            for step, checkpoint_dir in checkpoints_to_remove:
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                    logger.info(f"Removed old checkpoint: {checkpoint_dir}")

            # Update checkpoint list
            self.checkpoints = self.checkpoints[-self.save_total_limit:]

    def load_checkpoint(
        self,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False,
        map_location: str = 'cpu',
        strict: bool = True
    ) -> Dict[str, Any]:
        """
        Load a checkpoint with robust error handling.

        Args:
            checkpoint_path: Path to a specific checkpoint
            load_best: Whether to load the best model
            map_location: Device to map tensors to
            strict: Whether to strictly enforce state dict matching

        Returns:
            Loaded state dictionary

        Raises:
            RuntimeError: If checkpoint loading fails
        """
        try:
            if load_best:
                checkpoint_path = self.best_model_dir
            elif checkpoint_path is None:
                # Load latest checkpoint
                if self.checkpoints:
                    checkpoint_path = self.checkpoints[-1][1]
                else:
                    raise ValueError("No checkpoints found")

            # Ensure checkpoint_path is a string
            if checkpoint_path is None:
                raise ValueError("Checkpoint path is None")
            checkpoint_path = str(checkpoint_path)

            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

            logger.info(f"Loading checkpoint from {checkpoint_path}")

            # Load training state with error handling
            state_dict = {}
            state_path = os.path.join(checkpoint_path, 'training_state.json')
            if os.path.exists(state_path):
                try:
                    with open(state_path, 'r') as f:
                        training_state = json.load(f)
                    state_dict.update(training_state)
                except Exception as e:
                    logger.warning(f"Failed to load training state: {e}")

            # Load model state
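            # weights_only=True (available in PyTorch >= 1.13) restricts unpickling to
            # tensors and plain containers; the legacy fallback below keeps older
            # checkpoints that pickled arbitrary objects loadable, at the cost of that
            # protection.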
            model_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
            if os.path.exists(model_path):
                try:
                    state_dict['model_state_dict'] = torch.load(
                        model_path,
                        map_location=map_location,
                        weights_only=True  # Security: prevent arbitrary code execution
                    )
                except Exception as e:
                    # Try legacy loading without weights_only for backward compatibility
                    try:
                        state_dict['model_state_dict'] = torch.load(
                            model_path,
                            map_location=map_location
                        )
                        logger.warning("Loaded checkpoint using legacy mode (weights_only=False)")
                    except Exception as legacy_e:
                        raise RuntimeError(
                            f"Failed to load model state: {e}, Legacy attempt: {legacy_e}"
                        )
            else:
                logger.warning(f"Model state not found at {model_path}")

            # Load optimizer state
            optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
            if os.path.exists(optimizer_path):
                try:
                    state_dict['optimizer_state_dict'] = torch.load(
                        optimizer_path,
                        map_location=map_location,
                        weights_only=True
                    )
                except Exception as e:
                    # Try legacy loading
                    try:
                        state_dict['optimizer_state_dict'] = torch.load(
                            optimizer_path,
                            map_location=map_location
                        )
                    except Exception as legacy_e:
                        logger.warning(
                            f"Failed to load optimizer state: {e} (legacy attempt: {legacy_e})"
                        )

            # Load scheduler state
            scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
            if os.path.exists(scheduler_path):
                try:
                    state_dict['scheduler_state_dict'] = torch.load(
                        scheduler_path,
                        map_location=map_location,
                        weights_only=True
                    )
                except Exception as e:
                    # Try legacy loading
                    try:
                        state_dict['scheduler_state_dict'] = torch.load(
                            scheduler_path,
                            map_location=map_location
                        )
                    except Exception as legacy_e:
                        logger.warning(
                            f"Failed to load scheduler state: {e} (legacy attempt: {legacy_e})"
                        )

            # Load config if available
            config_path = os.path.join(checkpoint_path, 'config.json')
            if os.path.exists(config_path):
                try:
                    with open(config_path, 'r') as f:
                        state_dict['config'] = json.load(f)
                except Exception as e:
                    logger.warning(f"Failed to load config: {e}")

            logger.info(f"Successfully loaded checkpoint from {checkpoint_path}")
            return state_dict

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise RuntimeError(f"Checkpoint loading failed: {e}")

    def get_latest_checkpoint(self) -> Optional[str]:
        """Get the path to the latest checkpoint."""
        if self.checkpoints:
            return self.checkpoints[-1][1]
        return None

    def get_best_checkpoint(self) -> str:
        """Get the path to the best model checkpoint."""
        return self.best_model_dir

    def checkpoint_exists(self, step: int) -> bool:
        """Check whether a checkpoint exists for the given step."""
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        return os.path.exists(checkpoint_dir)

    def _safe_torch_save(self, obj: Any, path: str):
        """Save a torch object with error handling."""
        temp_path = f"{path}.tmp"
        try:
            torch.save(obj, temp_path)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")

    def _safe_json_save(self, obj: Any, path: str):
        """Save JSON with an atomic write."""
        temp_path = f"{path}.tmp"
        try:
            with open(temp_path, 'w') as f:
                json.dump(obj, f, indent=2)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")
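
# Minimal usage sketch for CheckpointManager. The tiny nn.Linear model, the SGD
# optimizer, and the "outputs" directory below are illustrative placeholders, not
# part of this module; a real training loop would pass its own objects.
def _example_checkpoint_usage():
    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    manager = CheckpointManager(output_dir="outputs", save_total_limit=2)

    # Save a checkpoint every step; mark the last one as best for illustration
    for step in range(1, 4):
        manager.save(
            state_dict={
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': step,
            },
            step=step,
            is_best=(step == 3),
            metrics={'loss': 1.0 / step},
        )

    # Resume from the latest checkpoint (or pass load_best=True for the best model)
    state = manager.load_checkpoint()
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
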

def save_model_for_inference(
    model: torch.nn.Module,
    tokenizer: Any,
    output_dir: str,
    model_name: str = "model",
    save_config: bool = True,
    additional_files: Optional[Dict[str, Any]] = None
):
    """
    Save a model in a format ready for inference.

    Args:
        model: The model to save
        tokenizer: Tokenizer to save
        output_dir: Output directory
        model_name: Name for the model
        save_config: Whether to save the model config
        additional_files: Additional files to include
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save model weights
    model_path = os.path.join(output_dir, f"{model_name}.pt")
    torch.save(model.state_dict(), model_path)
    logger.info(f"Saved model to {model_path}")

    # Save tokenizer
    if hasattr(tokenizer, 'save_pretrained'):
        tokenizer.save_pretrained(output_dir)
        logger.info(f"Saved tokenizer to {output_dir}")

    # Save config if available
    if save_config and hasattr(model, 'config'):
        model.config.save_pretrained(output_dir)  # type: ignore[attr-defined]

    # Save additional files
    if additional_files:
        for filename, content in additional_files.items():
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'w') as f:
                if isinstance(content, dict):
                    json.dump(content, f, indent=2)
                else:
                    f.write(str(content))

    # Create README
    readme_content = f"""# {model_name} Model

This directory contains a fine-tuned model ready for inference.

## Files:
- `{model_name}.pt`: Model weights
- `tokenizer_config.json`: Tokenizer configuration
- `special_tokens_map.json`: Special tokens mapping
- `vocab.txt` or `tokenizer.json`: Vocabulary file

## Loading the model:
```python
import torch
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('{output_dir}')

# Load model
model = YourModelClass()  # Initialize your model class
model.load_state_dict(torch.load('{model_name}.pt'))
model.eval()
```

## Model Information:
- Model type: {model.__class__.__name__}
- Saved on: {datetime.now().isoformat()}
"""

    readme_path = os.path.join(output_dir, 'README.md')
    with open(readme_path, 'w') as f:
        f.write(readme_content)

    logger.info(f"Model saved for inference in {output_dir}")
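
# Minimal usage sketch for save_model_for_inference. The AutoTokenizer import, the
# "bert-base-uncased" checkpoint, and the "export" directory are assumptions for
# illustration only; any object with a save_pretrained() method works as the
# tokenizer, and objects without one are simply skipped.
def _example_inference_export():
    from transformers import AutoTokenizer  # assumed available, as in the generated README

    model = torch.nn.Linear(16, 4)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    save_model_for_inference(
        model=model,
        tokenizer=tokenizer,
        output_dir="export",
        model_name="classifier",
        additional_files={"label_map.json": {"0": "negative", "1": "positive"}},
    )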