init commit
utils/checkpoint.py
@@ -0,0 +1,465 @@
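"""Checkpoint utilities (summary added for readability).

CheckpointManager handles atomic checkpoint saving/loading with a configurable
retention limit and best-model tracking; save_model_for_inference exports a
trained model (plus tokenizer and config, when available) for deployment.
"""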
import os
import torch
import json
import shutil
from typing import Dict, Any, Optional, List
import logging
from datetime import datetime
import glob

logger = logging.getLogger(__name__)


class CheckpointManager:
    """
    Manages model checkpoints with automatic cleanup and best-model tracking.
    """

    def __init__(
        self,
        output_dir: str,
        save_total_limit: int = 3,
        best_model_dir: str = "best_model",
        save_optimizer: bool = True,
        save_scheduler: bool = True
    ):
        self.output_dir = output_dir
        self.save_total_limit = save_total_limit
        self.best_model_dir = os.path.join(output_dir, best_model_dir)
        self.save_optimizer = save_optimizer
        self.save_scheduler = save_scheduler

        # Create directories
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(self.best_model_dir, exist_ok=True)

        # Track checkpoints
        self.checkpoints = []
        self._load_existing_checkpoints()

    def _load_existing_checkpoints(self):
        """Load the list of existing checkpoints from the output directory."""
        checkpoint_dirs = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))

        # Sort by step number
        self.checkpoints = []
        for cp_dir in checkpoint_dirs:
            try:
                step = int(cp_dir.split("-")[-1])
                self.checkpoints.append((step, cp_dir))
            except ValueError:
                # Ignore directories whose suffix is not a step number
                continue

        self.checkpoints.sort(key=lambda x: x[0])

    def save(
        self,
        state_dict: Dict[str, Any],
        step: int,
        is_best: bool = False,
        metrics: Optional[Dict[str, float]] = None
    ) -> str:
        """
        Save a checkpoint with atomic writes and error handling.

        Args:
            state_dict: State dictionary containing model, optimizer, etc.
            step: Current training step
            is_best: Whether this is the best model so far
            metrics: Optional metrics to save with the checkpoint

        Returns:
            Path to the saved checkpoint
        """
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        temp_dir = os.path.join(self.output_dir, f"checkpoint-{step}-tmp")

        try:
            # Create temporary directory for atomic write
            os.makedirs(temp_dir, exist_ok=True)

            # Save individual components
            checkpoint_files = {}

            # Save model with error handling
            if 'model_state_dict' in state_dict:
                model_path = os.path.join(temp_dir, 'pytorch_model.bin')
                self._safe_torch_save(state_dict['model_state_dict'], model_path)
                checkpoint_files['model'] = 'pytorch_model.bin'

            # Save optimizer
            if self.save_optimizer and 'optimizer_state_dict' in state_dict:
                optimizer_path = os.path.join(temp_dir, 'optimizer.pt')
                self._safe_torch_save(state_dict['optimizer_state_dict'], optimizer_path)
                checkpoint_files['optimizer'] = 'optimizer.pt'

            # Save scheduler
            if self.save_scheduler and 'scheduler_state_dict' in state_dict:
                scheduler_path = os.path.join(temp_dir, 'scheduler.pt')
                self._safe_torch_save(state_dict['scheduler_state_dict'], scheduler_path)
                checkpoint_files['scheduler'] = 'scheduler.pt'

            # Build training state
            training_state = {
                'step': step,
                'checkpoint_files': checkpoint_files,
                'timestamp': datetime.now().isoformat()
            }

            # Add additional state info
            for key in ['global_step', 'current_epoch', 'best_metric']:
                if key in state_dict:
                    training_state[key] = state_dict[key]

            if metrics:
                training_state['metrics'] = metrics

            # Save config if provided
            if 'config' in state_dict:
                config_path = os.path.join(temp_dir, 'config.json')
                self._safe_json_save(state_dict['config'], config_path)

            # Save training state
            state_path = os.path.join(temp_dir, 'training_state.json')
            self._safe_json_save(training_state, state_path)

            # Atomic rename from temp to final directory
            if os.path.exists(checkpoint_dir):
                # Back up the existing checkpoint
                backup_dir = f"{checkpoint_dir}.backup"
                if os.path.exists(backup_dir):
                    shutil.rmtree(backup_dir)
                shutil.move(checkpoint_dir, backup_dir)

            # Move temp directory to final location
            shutil.move(temp_dir, checkpoint_dir)

            # Remove the backup if the move succeeded
            if os.path.exists(f"{checkpoint_dir}.backup"):
                shutil.rmtree(f"{checkpoint_dir}.backup")

            # Add to checkpoint list
            self.checkpoints.append((step, checkpoint_dir))

            # Save best model
            if is_best:
                self._save_best_model(checkpoint_dir, step, metrics)

            # Clean up old checkpoints
            self._cleanup_checkpoints()

            logger.info(f"Saved checkpoint to {checkpoint_dir}")
            return checkpoint_dir

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            # Clean up the temp directory
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            # Restore the backup if it exists
            backup_dir = f"{checkpoint_dir}.backup"
            if os.path.exists(backup_dir):
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                shutil.move(backup_dir, checkpoint_dir)
            raise RuntimeError(f"Checkpoint save failed: {e}")

    def _save_best_model(
        self,
        checkpoint_dir: str,
        step: int,
        metrics: Optional[Dict[str, float]] = None
    ):
        """Save the best model."""
        # Copy checkpoint files to the best-model directory
        for filename in os.listdir(checkpoint_dir):
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(self.best_model_dir, filename)
            if os.path.isfile(src):
                shutil.copy2(src, dst)

        # Save best-model info
        best_info = {
            'step': step,
            'source_checkpoint': checkpoint_dir,
            'timestamp': datetime.now().isoformat()
        }
        if metrics:
            best_info['metrics'] = metrics

        best_info_path = os.path.join(self.best_model_dir, 'best_model_info.json')
        with open(best_info_path, 'w') as f:
            json.dump(best_info, f, indent=2)

        logger.info(f"Saved best model from step {step}")

    def _cleanup_checkpoints(self):
        """Remove old checkpoints to respect save_total_limit."""
        if self.save_total_limit is not None and len(self.checkpoints) > self.save_total_limit:
            # Remove the oldest checkpoints
            checkpoints_to_remove = self.checkpoints[:-self.save_total_limit]

            for step, checkpoint_dir in checkpoints_to_remove:
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                    logger.info(f"Removed old checkpoint: {checkpoint_dir}")

            # Update the checkpoint list
            self.checkpoints = self.checkpoints[-self.save_total_limit:]

    def load_checkpoint(
        self,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False,
        map_location: str = 'cpu',
        strict: bool = True
    ) -> Dict[str, Any]:
        """
        Load a checkpoint with robust error handling.

        Args:
            checkpoint_path: Path to a specific checkpoint
            load_best: Whether to load the best model
            map_location: Device to map tensors to
            strict: Whether to strictly enforce state dict matching

        Returns:
            Loaded state dictionary

        Raises:
            RuntimeError: If checkpoint loading fails
        """
        try:
            if load_best:
                checkpoint_path = self.best_model_dir
            elif checkpoint_path is None:
                # Load the latest checkpoint
                if self.checkpoints:
                    checkpoint_path = self.checkpoints[-1][1]
                else:
                    raise ValueError("No checkpoints found")

            # Ensure checkpoint_path is a string
            if checkpoint_path is None:
                raise ValueError("Checkpoint path is None")

            checkpoint_path = str(checkpoint_path)

            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

            logger.info(f"Loading checkpoint from {checkpoint_path}")

            # Load training state with error handling
            state_dict = {}
            state_path = os.path.join(checkpoint_path, 'training_state.json')
            if os.path.exists(state_path):
                try:
                    with open(state_path, 'r') as f:
                        training_state = json.load(f)
                    state_dict.update(training_state)
                except Exception as e:
                    logger.warning(f"Failed to load training state: {e}")
                    training_state = {}
            else:
                training_state = {}

            # Load model state
            model_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
            if os.path.exists(model_path):
                try:
                    state_dict['model_state_dict'] = torch.load(
                        model_path,
                        map_location=map_location,
                        weights_only=True  # Security: prevent arbitrary code execution
                    )
                except Exception as e:
                    # Fall back to legacy loading without weights_only for backward compatibility
                    try:
                        state_dict['model_state_dict'] = torch.load(
                            model_path,
                            map_location=map_location
                        )
                        logger.warning("Loaded checkpoint using legacy mode (weights_only=False)")
                    except Exception as legacy_e:
                        raise RuntimeError(f"Failed to load model state: {e}, legacy attempt: {legacy_e}")
            else:
                logger.warning(f"Model state not found at {model_path}")

            # Load optimizer state
            optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
            if os.path.exists(optimizer_path):
                try:
                    state_dict['optimizer_state_dict'] = torch.load(
                        optimizer_path,
                        map_location=map_location,
                        weights_only=True
                    )
                except Exception as e:
                    # Fall back to legacy loading
                    try:
                        state_dict['optimizer_state_dict'] = torch.load(
                            optimizer_path,
                            map_location=map_location
                        )
                    except Exception as legacy_e:
                        logger.warning(f"Failed to load optimizer state: {e}; legacy attempt: {legacy_e}")

            # Load scheduler state
            scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
            if os.path.exists(scheduler_path):
                try:
                    state_dict['scheduler_state_dict'] = torch.load(
                        scheduler_path,
                        map_location=map_location,
                        weights_only=True
                    )
                except Exception as e:
                    # Fall back to legacy loading
                    try:
                        state_dict['scheduler_state_dict'] = torch.load(
                            scheduler_path,
                            map_location=map_location
                        )
                    except Exception as legacy_e:
                        logger.warning(f"Failed to load scheduler state: {e}; legacy attempt: {legacy_e}")

            # Load config if available
            config_path = os.path.join(checkpoint_path, 'config.json')
            if os.path.exists(config_path):
                try:
                    with open(config_path, 'r') as f:
                        state_dict['config'] = json.load(f)
                except Exception as e:
                    logger.warning(f"Failed to load config: {e}")

            logger.info(f"Successfully loaded checkpoint from {checkpoint_path}")
            return state_dict

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise RuntimeError(f"Checkpoint loading failed: {e}")

    def get_latest_checkpoint(self) -> Optional[str]:
        """Get the path to the latest checkpoint."""
        if self.checkpoints:
            return self.checkpoints[-1][1]
        return None

    def get_best_checkpoint(self) -> str:
        """Get the path to the best-model checkpoint."""
        return self.best_model_dir

    def checkpoint_exists(self, step: int) -> bool:
        """Check whether a checkpoint exists for the given step."""
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        return os.path.exists(checkpoint_dir)

    def _safe_torch_save(self, obj: Any, path: str):
        """Save a torch object with error handling."""
        temp_path = f"{path}.tmp"
        try:
            torch.save(obj, temp_path)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")

    def _safe_json_save(self, obj: Any, path: str):
        """Save JSON with an atomic write."""
        temp_path = f"{path}.tmp"
        try:
            with open(temp_path, 'w') as f:
                json.dump(obj, f, indent=2)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")


def save_model_for_inference(
    model: torch.nn.Module,
    tokenizer: Any,
    output_dir: str,
    model_name: str = "model",
    save_config: bool = True,
    additional_files: Optional[Dict[str, str]] = None
):
    """
    Save a model in a format ready for inference.

    Args:
        model: The model to save
        tokenizer: Tokenizer to save
        output_dir: Output directory
        model_name: Name for the model
        save_config: Whether to save the model config
        additional_files: Additional files to include
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save model
    model_path = os.path.join(output_dir, f"{model_name}.pt")
    torch.save(model.state_dict(), model_path)
    logger.info(f"Saved model to {model_path}")

    # Save tokenizer
    if hasattr(tokenizer, 'save_pretrained'):
        tokenizer.save_pretrained(output_dir)
        logger.info(f"Saved tokenizer to {output_dir}")

    # Save config if available
    if save_config and hasattr(model, 'config'):
        model.config.save_pretrained(output_dir)  # type: ignore[attr-defined]

    # Save additional files
    if additional_files:
        for filename, content in additional_files.items():
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'w') as f:
                if isinstance(content, dict):
                    json.dump(content, f, indent=2)
                else:
                    f.write(str(content))

    # Create README
    readme_content = f"""# {model_name} Model

This directory contains a fine-tuned model ready for inference.

## Files:
- `{model_name}.pt`: Model weights
- `tokenizer_config.json`: Tokenizer configuration
- `special_tokens_map.json`: Special tokens mapping
- `vocab.txt` or `tokenizer.json`: Vocabulary file

## Loading the model:

```python
import torch
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('{output_dir}')

# Load model
model = YourModelClass()  # Initialize your model class
model.load_state_dict(torch.load('{model_name}.pt'))
model.eval()
```

## Model Information:
- Model type: {model.__class__.__name__}
- Saved on: {datetime.now().isoformat()}
"""

    readme_path = os.path.join(output_dir, 'README.md')
    with open(readme_path, 'w') as f:
        f.write(readme_content)

    logger.info(f"Model saved for inference in {output_dir}")