import glob
import json
import logging
import os
import shutil
from datetime import datetime
from typing import Any, Dict, Optional

import torch

logger = logging.getLogger(__name__)

class CheckpointManager:
    """
    Manages model checkpoints with automatic cleanup and best-model tracking.
    """

    def __init__(
        self,
        output_dir: str,
        save_total_limit: int = 3,
        best_model_dir: str = "best_model",
        save_optimizer: bool = True,
        save_scheduler: bool = True
    ):
        self.output_dir = output_dir
        self.save_total_limit = save_total_limit
        self.best_model_dir = os.path.join(output_dir, best_model_dir)
        self.save_optimizer = save_optimizer
        self.save_scheduler = save_scheduler

        # Create directories
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(self.best_model_dir, exist_ok=True)

        # Track checkpoints as (step, path) tuples
        self.checkpoints = []
        self._load_existing_checkpoints()

    def _load_existing_checkpoints(self):
        """Load the list of existing checkpoints, sorted by step number."""
        checkpoint_dirs = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))

        self.checkpoints = []
        for cp_dir in checkpoint_dirs:
            try:
                step = int(cp_dir.split("-")[-1])
                self.checkpoints.append((step, cp_dir))
            except ValueError:
                # Skip directories whose suffix is not a step number
                continue

        self.checkpoints.sort(key=lambda x: x[0])

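    # On-disk layout produced by save(), for reference (derived from the code below):
    #   <output_dir>/
    #     checkpoint-<step>/
    #       pytorch_model.bin      # model weights
    #       optimizer.pt           # only if save_optimizer and state provided
    #       scheduler.pt           # only if save_scheduler and state provided
    #       config.json            # only if 'config' is in state_dict
    #       training_state.json    # step, file manifest, timestamp, metrics (if provided)
    #     best_model/              # copy of the best checkpoint + best_model_info.json
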
    def save(
        self,
        state_dict: Dict[str, Any],
        step: int,
        is_best: bool = False,
        metrics: Optional[Dict[str, float]] = None
    ) -> str:
        """
        Save a checkpoint with atomic writes and error handling.

        Args:
            state_dict: State dictionary containing model, optimizer, etc.
            step: Current training step
            is_best: Whether this is the best model so far
            metrics: Optional metrics to save with the checkpoint

        Returns:
            Path to the saved checkpoint

        Raises:
            RuntimeError: If the checkpoint cannot be saved
        """
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        temp_dir = os.path.join(self.output_dir, f"checkpoint-{step}-tmp")

        try:
            # Create temporary directory for atomic write
            os.makedirs(temp_dir, exist_ok=True)

            # Save individual components
            checkpoint_files = {}

            # Save model with error handling
            if 'model_state_dict' in state_dict:
                model_path = os.path.join(temp_dir, 'pytorch_model.bin')
                # Extra safety check: refuse to write an empty state dict
                model_state = state_dict['model_state_dict']
                if not model_state:
                    raise RuntimeError("Model state dict is empty")

                self._safe_torch_save(model_state, model_path)
                checkpoint_files['model'] = 'pytorch_model.bin'

            # Save optimizer
            if self.save_optimizer and 'optimizer_state_dict' in state_dict:
                optimizer_path = os.path.join(temp_dir, 'optimizer.pt')
                self._safe_torch_save(state_dict['optimizer_state_dict'], optimizer_path)
                checkpoint_files['optimizer'] = 'optimizer.pt'

            # Save scheduler
            if self.save_scheduler and 'scheduler_state_dict' in state_dict:
                scheduler_path = os.path.join(temp_dir, 'scheduler.pt')
                self._safe_torch_save(state_dict['scheduler_state_dict'], scheduler_path)
                checkpoint_files['scheduler'] = 'scheduler.pt'

            # Collect training state metadata
            training_state = {
                'step': step,
                'checkpoint_files': checkpoint_files,
                'timestamp': datetime.now().isoformat()
            }

            # Add additional state info
            for key in ['global_step', 'current_epoch', 'best_metric']:
                if key in state_dict:
                    training_state[key] = state_dict[key]

            if metrics:
                training_state['metrics'] = metrics

            # Save config if provided
            if 'config' in state_dict:
                config_path = os.path.join(temp_dir, 'config.json')
                self._safe_json_save(state_dict['config'], config_path)

            # Save training state
            state_path = os.path.join(temp_dir, 'training_state.json')
            self._safe_json_save(training_state, state_path)

            # Move the temp directory into its final location
            if os.path.exists(checkpoint_dir):
                # Back up the existing checkpoint first
                backup_dir = f"{checkpoint_dir}.backup"
                if os.path.exists(backup_dir):
                    shutil.rmtree(backup_dir)
                shutil.move(checkpoint_dir, backup_dir)

            shutil.move(temp_dir, checkpoint_dir)

            # Remove backup if successful
            if os.path.exists(f"{checkpoint_dir}.backup"):
                shutil.rmtree(f"{checkpoint_dir}.backup")

            # Add to checkpoint list
            self.checkpoints.append((step, checkpoint_dir))

            # Save best model
            if is_best:
                self._save_best_model(checkpoint_dir, step, metrics)

            # Cleanup old checkpoints
            self._cleanup_checkpoints()

            logger.info(f"Saved checkpoint to {checkpoint_dir}")
            return checkpoint_dir

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            # Clean up the temp directory
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            # Restore the backup if one exists
            backup_dir = f"{checkpoint_dir}.backup"
            if os.path.exists(backup_dir):
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                shutil.move(backup_dir, checkpoint_dir)
            raise RuntimeError(f"Checkpoint save failed: {e}")

    def _save_best_model(
        self,
        checkpoint_dir: str,
        step: int,
        metrics: Optional[Dict[str, float]] = None
    ):
        """Save the best model"""
        # Copy checkpoint files to the best model directory
        for filename in os.listdir(checkpoint_dir):
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(self.best_model_dir, filename)
            if os.path.isfile(src):
                shutil.copy2(src, dst)

        # Save best model info
        best_info = {
            'step': step,
            'source_checkpoint': checkpoint_dir,
            'timestamp': datetime.now().isoformat()
        }
        if metrics:
            best_info['metrics'] = metrics

        best_info_path = os.path.join(self.best_model_dir, 'best_model_info.json')
        with open(best_info_path, 'w') as f:
            json.dump(best_info, f, indent=2)

        logger.info(f"Saved best model from step {step}")

    def _cleanup_checkpoints(self):
        """Remove old checkpoints to maintain save_total_limit"""
        if self.save_total_limit is not None and len(self.checkpoints) > self.save_total_limit:
            # Remove the oldest checkpoints
            checkpoints_to_remove = self.checkpoints[:-self.save_total_limit]

            for step, checkpoint_dir in checkpoints_to_remove:
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                    logger.info(f"Removed old checkpoint: {checkpoint_dir}")

            # Update checkpoint list
            self.checkpoints = self.checkpoints[-self.save_total_limit:]

    def load_checkpoint(
        self,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False,
        map_location: str = 'cpu',
        strict: bool = True
    ) -> Dict[str, Any]:
        """
        Load a checkpoint with robust error handling.

        Args:
            checkpoint_path: Path to a specific checkpoint
            load_best: Whether to load the best model
            map_location: Device to map tensors to
            strict: Whether to strictly enforce state dict matching (not applied
                here; this method returns raw state dicts for the caller to load)

        Returns:
            Loaded state dictionary

        Raises:
            RuntimeError: If checkpoint loading fails
        """
        try:
            if load_best:
                checkpoint_path = self.best_model_dir
            elif checkpoint_path is None:
                # Fall back to the latest checkpoint
                if self.checkpoints:
                    checkpoint_path = self.checkpoints[-1][1]
                else:
                    raise ValueError("No checkpoints found")

            # checkpoint_path should be set by now; normalise to str for os.path
            if checkpoint_path is None:
                raise ValueError("Checkpoint path is None")
            checkpoint_path = str(checkpoint_path)

            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

            logger.info(f"Loading checkpoint from {checkpoint_path}")

            # Load training state with error handling
            state_dict = {}
            state_path = os.path.join(checkpoint_path, 'training_state.json')
            if os.path.exists(state_path):
                try:
                    with open(state_path, 'r') as f:
                        training_state = json.load(f)
                    state_dict.update(training_state)
                except Exception as e:
                    logger.warning(f"Failed to load training state: {e}")

            # Load model state
            model_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
            if os.path.exists(model_path):
                try:
                    state_dict['model_state_dict'] = torch.load(
                        model_path,
                        map_location=map_location,
                        weights_only=True  # Security: prevent arbitrary code execution
                    )
                except Exception as e:
                    # Fall back to legacy loading (weights_only=False) for older checkpoints
                    try:
                        state_dict['model_state_dict'] = torch.load(
                            model_path,
                            map_location=map_location
                        )
                        logger.warning("Loaded checkpoint using legacy mode (weights_only=False)")
                    except Exception as legacy_e:
                        raise RuntimeError(
                            f"Failed to load model state: {e}, legacy attempt: {legacy_e}"
                        )
            else:
                logger.warning(f"Model state not found at {model_path}")

            # Load optimizer state
            optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
            if os.path.exists(optimizer_path):
                try:
                    state_dict['optimizer_state_dict'] = torch.load(
                        optimizer_path,
                        map_location=map_location,
                        weights_only=True
                    )
                except Exception as e:
                    # Fall back to legacy loading
                    try:
                        state_dict['optimizer_state_dict'] = torch.load(
                            optimizer_path,
                            map_location=map_location
                        )
                    except Exception as legacy_e:
                        logger.warning(
                            f"Failed to load optimizer state: {e} (legacy attempt: {legacy_e})"
                        )

            # Load scheduler state
            scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
            if os.path.exists(scheduler_path):
                try:
                    state_dict['scheduler_state_dict'] = torch.load(
                        scheduler_path,
                        map_location=map_location,
                        weights_only=True
                    )
                except Exception as e:
                    # Fall back to legacy loading
                    try:
                        state_dict['scheduler_state_dict'] = torch.load(
                            scheduler_path,
                            map_location=map_location
                        )
                    except Exception as legacy_e:
                        logger.warning(
                            f"Failed to load scheduler state: {e} (legacy attempt: {legacy_e})"
                        )

            # Load config if available
            config_path = os.path.join(checkpoint_path, 'config.json')
            if os.path.exists(config_path):
                try:
                    with open(config_path, 'r') as f:
                        state_dict['config'] = json.load(f)
                except Exception as e:
                    logger.warning(f"Failed to load config: {e}")

            logger.info(f"Successfully loaded checkpoint from {checkpoint_path}")
            return state_dict

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise RuntimeError(f"Checkpoint loading failed: {e}")

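    # Resume sketch (illustrative only; assumes `manager`, `model`, `optimizer`,
    # and `scheduler` already exist in the caller's scope):
    #   state = manager.load_checkpoint()  # latest checkpoint by default
    #   model.load_state_dict(state["model_state_dict"], strict=True)
    #   if "optimizer_state_dict" in state:
    #       optimizer.load_state_dict(state["optimizer_state_dict"])
    #   if "scheduler_state_dict" in state:
    #       scheduler.load_state_dict(state["scheduler_state_dict"])
    #   start_step = state.get("global_step", state.get("step", 0))
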
    def get_latest_checkpoint(self) -> Optional[str]:
        """Get path to latest checkpoint"""
        if self.checkpoints:
            return self.checkpoints[-1][1]
        return None

    def get_best_checkpoint(self) -> str:
        """Get path to best model checkpoint"""
        return self.best_model_dir

    def checkpoint_exists(self, step: int) -> bool:
        """Check if checkpoint exists for given step"""
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        return os.path.exists(checkpoint_dir)

    def _safe_torch_save(self, obj: Any, path: str):
        """Save a torch object with robust error handling and fallback methods."""
        temp_path = f"{path}.tmp"

        # Check disk space before attempting the save (require at least 1GB free)
        try:
            free_space = shutil.disk_usage(os.path.dirname(path)).free
            if free_space < 1024 * 1024 * 1024:  # 1GB
                raise RuntimeError(
                    f"Insufficient disk space: {free_space / (1024**3):.1f}GB available"
                )
        except OSError as e:
            logger.warning(f"Could not check disk space: {e}")

        # Try multiple save methods with increasing robustness
        save_methods = [
            # Method 1: legacy (non-zipfile) serialization
            lambda: torch.save(obj, temp_path, _use_new_zipfile_serialization=False),
            # Method 2: default PyTorch save
            lambda: torch.save(obj, temp_path),
            # Method 3: save with pickle protocol 4 (more robust for large objects)
            lambda: torch.save(obj, temp_path, pickle_protocol=4),
            # Method 4: move tensors to CPU and retry (in case of GPU memory issues)
            lambda: self._cpu_fallback_save(obj, temp_path)
        ]

        last_error = None
        for i, save_method in enumerate(save_methods):
            try:
                # Clear GPU cache before saving to free up memory
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                save_method()

                # Verify the file was created and has a reasonable size
                if os.path.exists(temp_path):
                    size = os.path.getsize(temp_path)
                    if size < 100:  # File too small, likely corrupted
                        raise RuntimeError(f"Saved file too small: {size} bytes")

                    # Atomic rename
                    os.replace(temp_path, path)
                    logger.info(
                        f"Successfully saved {path} using method {i + 1} "
                        f"(size: {size / 1024 / 1024:.1f}MB)"
                    )
                    return
                else:
                    raise RuntimeError("Temp file was not created")

            except Exception as e:
                last_error = e
                logger.warning(f"Save method {i + 1} failed for {path}: {e}")
                if os.path.exists(temp_path):
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass
                continue

        # All methods failed
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
        raise RuntimeError(f"Failed to save {path} after trying all methods: {last_error}")

    def _cpu_fallback_save(self, obj: Any, path: str):
        """Fallback save method: move tensors to CPU before saving."""
        # Build a CPU copy of the object without modifying the original tensors
        if hasattr(obj, 'state_dict'):
            # For modules/optimizers exposing a state_dict() method
            cpu_obj = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                       for k, v in obj.state_dict().items()}
        elif isinstance(obj, dict):
            # For plain state dict dictionaries (one level of nesting handled)
            cpu_obj = {}
            for k, v in obj.items():
                if isinstance(v, torch.Tensor):
                    cpu_obj[k] = v.cpu()
                elif isinstance(v, dict):
                    cpu_obj[k] = {k2: v2.cpu() if isinstance(v2, torch.Tensor) else v2
                                  for k2, v2 in v.items()}
                else:
                    cpu_obj[k] = v
        else:
            # For other objects, save as-is
            cpu_obj = obj

        torch.save(cpu_obj, path, pickle_protocol=4)

    def _safe_json_save(self, obj: Any, path: str):
        """Save JSON with atomic write"""
        temp_path = f"{path}.tmp"
        try:
            with open(temp_path, 'w') as f:
                json.dump(obj, f, indent=2)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")

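
# Illustrative usage sketch for CheckpointManager (not part of this module's API;
# `model`, `optimizer`, `eval_loss`, `num_steps`, and `save_every` are assumed to
# exist in the caller's training loop):
#
#   manager = CheckpointManager(output_dir="runs/exp1", save_total_limit=3)
#   best_loss = float("inf")
#   for step in range(1, num_steps + 1):
#       ...  # forward/backward/optimizer.step()
#       if step % save_every == 0:
#           is_best = eval_loss < best_loss
#           best_loss = min(best_loss, eval_loss)
#           manager.save(
#               {
#                   "model_state_dict": model.state_dict(),
#                   "optimizer_state_dict": optimizer.state_dict(),
#                   "global_step": step,
#               },
#               step=step,
#               is_best=is_best,
#               metrics={"eval_loss": eval_loss},
#           )
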
def save_model_for_inference(
    model: torch.nn.Module,
    tokenizer: Any,
    output_dir: str,
    model_name: str = "model",
    save_config: bool = True,
    additional_files: Optional[Dict[str, Any]] = None
):
    """
    Save a model in a format ready for inference.

    Args:
        model: The model to save
        tokenizer: Tokenizer to save
        output_dir: Output directory
        model_name: Name for the model
        save_config: Whether to save the model config
        additional_files: Additional files to include (dict values are written
            as JSON, everything else as plain text)
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save model weights
    model_path = os.path.join(output_dir, f"{model_name}.pt")
    torch.save(model.state_dict(), model_path)
    logger.info(f"Saved model to {model_path}")

    # Save tokenizer
    if hasattr(tokenizer, 'save_pretrained'):
        tokenizer.save_pretrained(output_dir)
        logger.info(f"Saved tokenizer to {output_dir}")

    # Save config if available
    if save_config and hasattr(model, 'config'):
        model.config.save_pretrained(output_dir)  # type: ignore[attr-defined]

    # Save additional files
    if additional_files:
        for filename, content in additional_files.items():
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'w') as f:
                if isinstance(content, dict):
                    json.dump(content, f, indent=2)
                else:
                    f.write(str(content))

    # Create README
    readme_content = f"""# {model_name} Model

This directory contains a fine-tuned model ready for inference.

## Files:
- `{model_name}.pt`: Model weights
- `tokenizer_config.json`: Tokenizer configuration
- `special_tokens_map.json`: Special tokens mapping
- `vocab.txt` or `tokenizer.json`: Vocabulary file

## Loading the model:

```python
import torch
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('{output_dir}')

# Load model
model = YourModelClass()  # Initialize your model class
model.load_state_dict(torch.load('{model_name}.pt'))
model.eval()
```

## Model Information:
- Model type: {model.__class__.__name__}
- Saved on: {datetime.now().isoformat()}
"""

    readme_path = os.path.join(output_dir, 'README.md')
    with open(readme_path, 'w') as f:
        f.write(readme_content)

    logger.info(f"Model saved for inference in {output_dir}")
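
# Illustrative call (a sketch; `model` and `tokenizer` are assumed to come from
# the surrounding training code, and the paths/labels are placeholders):
#   save_model_for_inference(
#       model=model,
#       tokenizer=tokenizer,
#       output_dir="exported/my_model",
#       model_name="my_model",
#       additional_files={"labels.json": {"0": "negative", "1": "positive"}},
#   )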