# bge_finetune/utils/checkpoint.py

import glob
import json
import logging
import os
import shutil
from datetime import datetime
from typing import Any, Dict, Optional

import torch

logger = logging.getLogger(__name__)

class CheckpointManager:
"""
Manages model checkpoints with automatic cleanup and best model tracking
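
    Illustrative usage (a minimal sketch; the output directory is hypothetical and
    ``model`` is assumed to be a torch.nn.Module defined by the caller):

        manager = CheckpointManager(output_dir="outputs/bge_run", save_total_limit=3)
        latest = manager.get_latest_checkpoint()
        if latest is not None:
            state = manager.load_checkpoint(latest)
            model.load_state_dict(state["model_state_dict"])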
"""
def __init__(
self,
output_dir: str,
save_total_limit: int = 3,
best_model_dir: str = "best_model",
save_optimizer: bool = True,
save_scheduler: bool = True
):
self.output_dir = output_dir
self.save_total_limit = save_total_limit
self.best_model_dir = os.path.join(output_dir, best_model_dir)
self.save_optimizer = save_optimizer
self.save_scheduler = save_scheduler
# Create directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(self.best_model_dir, exist_ok=True)
# Track checkpoints
self.checkpoints = []
self._load_existing_checkpoints()
def _load_existing_checkpoints(self):
"""Load list of existing checkpoints"""
checkpoint_dirs = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))
# Sort by step number
self.checkpoints = []
        for cp_dir in checkpoint_dirs:
            try:
                step = int(os.path.basename(cp_dir).split("-")[-1])
                self.checkpoints.append((step, cp_dir))
            except ValueError:
                # Skip directories whose name does not end in a step number
                continue
        self.checkpoints.sort(key=lambda x: x[0])
def save(
self,
state_dict: Dict[str, Any],
step: int,
is_best: bool = False,
metrics: Optional[Dict[str, float]] = None
) -> str:
"""
Save checkpoint with atomic writes and error handling
Args:
state_dict: State dictionary containing model, optimizer, etc.
step: Current training step
is_best: Whether this is the best model so far
metrics: Optional metrics to save with checkpoint
Returns:
Path to saved checkpoint
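
        Example (illustrative; ``model``, ``optimizer``, ``scheduler``, the step
        number, and the metric value below are hypothetical):

            manager.save(
                state_dict={
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "global_step": 500,
                },
                step=500,
                is_best=True,
                metrics={"eval_loss": 0.42},
            )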
"""
checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
temp_dir = os.path.join(self.output_dir, f"checkpoint-{step}-tmp")
try:
# Create temporary directory for atomic write
os.makedirs(temp_dir, exist_ok=True)
# Save individual components
checkpoint_files = {}
# Save model with error handling
if 'model_state_dict' in state_dict:
model_path = os.path.join(temp_dir, 'pytorch_model.bin')
# Add extra safety check for model state dict
model_state = state_dict['model_state_dict']
if model_state is None or (isinstance(model_state, dict) and len(model_state) == 0):
raise RuntimeError("Model state dict is empty")
self._safe_torch_save(model_state, model_path)
checkpoint_files['model'] = 'pytorch_model.bin'
# Save optimizer
if self.save_optimizer and 'optimizer_state_dict' in state_dict:
optimizer_path = os.path.join(temp_dir, 'optimizer.pt')
self._safe_torch_save(state_dict['optimizer_state_dict'], optimizer_path)
checkpoint_files['optimizer'] = 'optimizer.pt'
# Save scheduler
if self.save_scheduler and 'scheduler_state_dict' in state_dict:
scheduler_path = os.path.join(temp_dir, 'scheduler.pt')
self._safe_torch_save(state_dict['scheduler_state_dict'], scheduler_path)
checkpoint_files['scheduler'] = 'scheduler.pt'
# Save training state
training_state = {
'step': step,
'checkpoint_files': checkpoint_files,
'timestamp': datetime.now().isoformat()
}
# Add additional state info
for key in ['global_step', 'current_epoch', 'best_metric']:
if key in state_dict:
training_state[key] = state_dict[key]
if metrics:
training_state['metrics'] = metrics
# Save config if provided
if 'config' in state_dict:
config_path = os.path.join(temp_dir, 'config.json')
self._safe_json_save(state_dict['config'], config_path)
# Save training state
state_path = os.path.join(temp_dir, 'training_state.json')
self._safe_json_save(training_state, state_path)
# Atomic rename from temp to final directory
if os.path.exists(checkpoint_dir):
# Backup existing checkpoint
backup_dir = f"{checkpoint_dir}.backup"
if os.path.exists(backup_dir):
shutil.rmtree(backup_dir)
shutil.move(checkpoint_dir, backup_dir)
# Move temp directory to final location
shutil.move(temp_dir, checkpoint_dir)
# Remove backup if successful
if os.path.exists(f"{checkpoint_dir}.backup"):
shutil.rmtree(f"{checkpoint_dir}.backup")
            # Track this checkpoint, dropping any stale entry for the same step
            self.checkpoints = [(s, d) for s, d in self.checkpoints if s != step]
            self.checkpoints.append((step, checkpoint_dir))
# Save best model
if is_best:
self._save_best_model(checkpoint_dir, step, metrics)
# Cleanup old checkpoints
self._cleanup_checkpoints()
logger.info(f"Saved checkpoint to {checkpoint_dir}")
return checkpoint_dir
except Exception as e:
logger.error(f"Failed to save checkpoint: {e}")
# Cleanup temp directory
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
# Restore backup if exists
backup_dir = f"{checkpoint_dir}.backup"
if os.path.exists(backup_dir):
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
shutil.move(backup_dir, checkpoint_dir)
raise RuntimeError(f"Checkpoint save failed: {e}")
def _save_best_model(
self,
checkpoint_dir: str,
step: int,
metrics: Optional[Dict[str, float]] = None
):
"""Save the best model"""
# Copy checkpoint to best model directory
for filename in os.listdir(checkpoint_dir):
src = os.path.join(checkpoint_dir, filename)
dst = os.path.join(self.best_model_dir, filename)
if os.path.isfile(src):
shutil.copy2(src, dst)
# Save best model info
best_info = {
'step': step,
'source_checkpoint': checkpoint_dir,
'timestamp': datetime.now().isoformat()
}
if metrics:
best_info['metrics'] = metrics
best_info_path = os.path.join(self.best_model_dir, 'best_model_info.json')
with open(best_info_path, 'w') as f:
json.dump(best_info, f, indent=2)
logger.info(f"Saved best model from step {step}")
def _cleanup_checkpoints(self):
"""Remove old checkpoints to maintain save_total_limit"""
if self.save_total_limit is not None and len(self.checkpoints) > self.save_total_limit:
# Remove oldest checkpoints
checkpoints_to_remove = self.checkpoints[:-self.save_total_limit]
for step, checkpoint_dir in checkpoints_to_remove:
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
logger.info(f"Removed old checkpoint: {checkpoint_dir}")
# Update checkpoint list
self.checkpoints = self.checkpoints[-self.save_total_limit:]
def load_checkpoint(
self,
checkpoint_path: Optional[str] = None,
load_best: bool = False,
map_location: str = 'cpu',
strict: bool = True
) -> Dict[str, Any]:
"""
Load checkpoint with robust error handling
Args:
checkpoint_path: Path to specific checkpoint
load_best: Whether to load best model
map_location: Device to map tensors to
            strict: Whether to strictly enforce state dict matching (currently unused;
                the caller is responsible for applying the returned state dict)
Returns:
Loaded state dictionary
Raises:
RuntimeError: If checkpoint loading fails
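
        Example (illustrative; assumes ``model`` and ``optimizer`` match the
        architecture and optimizer that produced the checkpoint):

            state = manager.load_checkpoint(load_best=True)
            model.load_state_dict(state["model_state_dict"])
            if "optimizer_state_dict" in state:
                optimizer.load_state_dict(state["optimizer_state_dict"])
            resume_step = state.get("step", 0)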
"""
try:
if load_best:
checkpoint_path = self.best_model_dir
elif checkpoint_path is None:
# Load latest checkpoint
if self.checkpoints:
checkpoint_path = self.checkpoints[-1][1]
else:
raise ValueError("No checkpoints found")
# Ensure checkpoint_path is a string
if checkpoint_path is None:
raise ValueError("Checkpoint path is None")
checkpoint_path = str(checkpoint_path) # Ensure it's a string
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
logger.info(f"Loading checkpoint from {checkpoint_path}")
# Load training state with error handling
state_dict = {}
state_path = os.path.join(checkpoint_path, 'training_state.json')
            if os.path.exists(state_path):
                try:
                    with open(state_path, 'r') as f:
                        state_dict.update(json.load(f))
                except Exception as e:
                    logger.warning(f"Failed to load training state: {e}")
# Load model state
model_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
if os.path.exists(model_path):
try:
state_dict['model_state_dict'] = torch.load(
model_path,
map_location=map_location,
weights_only=True # Security: prevent arbitrary code execution
)
except Exception as e:
# Try legacy loading without weights_only for backward compatibility
try:
state_dict['model_state_dict'] = torch.load(
model_path,
map_location=map_location
)
logger.warning("Loaded checkpoint using legacy mode (weights_only=False)")
except Exception as legacy_e:
raise RuntimeError(f"Failed to load model state: {e}, Legacy attempt: {legacy_e}")
else:
logger.warning(f"Model state not found at {model_path}")
# Load optimizer state
optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
if os.path.exists(optimizer_path):
try:
state_dict['optimizer_state_dict'] = torch.load(
optimizer_path,
map_location=map_location,
weights_only=True
)
except Exception as e:
# Try legacy loading
try:
state_dict['optimizer_state_dict'] = torch.load(
optimizer_path,
map_location=map_location
)
                    except Exception as legacy_e:
                        logger.warning(f"Failed to load optimizer state: {e}; legacy attempt: {legacy_e}")
# Load scheduler state
scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
if os.path.exists(scheduler_path):
try:
state_dict['scheduler_state_dict'] = torch.load(
scheduler_path,
map_location=map_location,
weights_only=True
)
except Exception as e:
# Try legacy loading
try:
state_dict['scheduler_state_dict'] = torch.load(
scheduler_path,
map_location=map_location
)
                    except Exception as legacy_e:
                        logger.warning(f"Failed to load scheduler state: {e}; legacy attempt: {legacy_e}")
# Load config if available
config_path = os.path.join(checkpoint_path, 'config.json')
if os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
state_dict['config'] = json.load(f)
except Exception as e:
logger.warning(f"Failed to load config: {e}")
logger.info(f"Successfully loaded checkpoint from {checkpoint_path}")
return state_dict
except Exception as e:
logger.error(f"Failed to load checkpoint: {e}")
raise RuntimeError(f"Checkpoint loading failed: {e}")
def get_latest_checkpoint(self) -> Optional[str]:
"""Get path to latest checkpoint"""
if self.checkpoints:
return self.checkpoints[-1][1]
return None
def get_best_checkpoint(self) -> str:
"""Get path to best model checkpoint"""
return self.best_model_dir
def checkpoint_exists(self, step: int) -> bool:
"""Check if checkpoint exists for given step"""
checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
return os.path.exists(checkpoint_dir)
def _safe_torch_save(self, obj: Any, path: str):
"""Save torch object with robust error handling and fallback methods"""
temp_path = f"{path}.tmp"
# Check disk space before attempting save (require at least 1GB free)
try:
import shutil
free_space = shutil.disk_usage(os.path.dirname(path)).free
if free_space < 1024 * 1024 * 1024: # 1GB
raise RuntimeError(f"Insufficient disk space: {free_space / (1024**3):.1f}GB available")
except Exception as e:
logger.warning(f"Could not check disk space: {e}")
# Try multiple save methods with increasing robustness
        save_methods = [
            # Method 1: legacy (non-zipfile) serialization format
            lambda: torch.save(obj, temp_path, _use_new_zipfile_serialization=False),
            # Method 2: default PyTorch save (new zipfile format)
            lambda: torch.save(obj, temp_path),
            # Method 3: save with pickle protocol 4 (more robust for large objects)
            lambda: torch.save(obj, temp_path, pickle_protocol=4),
            # Method 4: move tensors to CPU and retry (in case of GPU memory issues)
            lambda: self._cpu_fallback_save(obj, temp_path),
        ]
last_error = None
for i, save_method in enumerate(save_methods):
try:
# Clear GPU cache before saving to free up memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
save_method()
# Verify file was created and has reasonable size
if os.path.exists(temp_path):
size = os.path.getsize(temp_path)
if size < 100: # File too small, likely corrupted
raise RuntimeError(f"Saved file too small: {size} bytes")
# Atomic rename
os.replace(temp_path, path)
logger.info(f"Successfully saved {path} using method {i+1} (size: {size/1024/1024:.1f}MB)")
return
else:
raise RuntimeError("Temp file was not created")
except Exception as e:
last_error = e
logger.warning(f"Save method {i+1} failed for {path}: {e}")
                if os.path.exists(temp_path):
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass
                continue
# All methods failed
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
raise RuntimeError(f"Failed to save {path} after trying all methods: {last_error}")
def _cpu_fallback_save(self, obj: Any, path: str):
"""Fallback save method: move tensors to CPU and save"""
        # Build a copy of the state with all tensors moved to CPU
        if hasattr(obj, 'state_dict'):
            # For modules/optimizers exposing a state_dict() method
            cpu_obj = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                       for k, v in obj.state_dict().items()}
elif isinstance(obj, dict):
# For state dict dictionaries
cpu_obj = {}
for k, v in obj.items():
if isinstance(v, torch.Tensor):
cpu_obj[k] = v.cpu()
elif isinstance(v, dict):
cpu_obj[k] = {k2: v2.cpu() if isinstance(v2, torch.Tensor) else v2
for k2, v2 in v.items()}
else:
cpu_obj[k] = v
else:
# For other objects
cpu_obj = obj
torch.save(cpu_obj, path, pickle_protocol=4)
def _safe_json_save(self, obj: Any, path: str):
"""Save JSON with atomic write"""
temp_path = f"{path}.tmp"
try:
with open(temp_path, 'w') as f:
json.dump(obj, f, indent=2)
# Atomic rename
os.replace(temp_path, path)
except Exception as e:
if os.path.exists(temp_path):
os.remove(temp_path)
raise RuntimeError(f"Failed to save {path}: {e}")


def save_model_for_inference(
model: torch.nn.Module,
tokenizer: Any,
output_dir: str,
model_name: str = "model",
save_config: bool = True,
additional_files: Optional[Dict[str, str]] = None
):
"""
Save model in a format ready for inference
Args:
model: The model to save
tokenizer: Tokenizer to save
output_dir: Output directory
model_name: Name for the model
save_config: Whether to save model config
additional_files: Additional files to include
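
    Example (illustrative; ``model`` and ``tokenizer`` are assumed to be a trained
    module and its Hugging Face tokenizer, and the paths are hypothetical):

        save_model_for_inference(
            model,
            tokenizer,
            output_dir="outputs/bge_inference",
            model_name="bge_finetuned",
            additional_files={"training_args.json": {"epochs": 3}},
        )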
"""
os.makedirs(output_dir, exist_ok=True)
# Save model
model_path = os.path.join(output_dir, f"{model_name}.pt")
torch.save(model.state_dict(), model_path)
logger.info(f"Saved model to {model_path}")
# Save tokenizer
if hasattr(tokenizer, 'save_pretrained'):
tokenizer.save_pretrained(output_dir)
logger.info(f"Saved tokenizer to {output_dir}")
# Save config if available
    if save_config and hasattr(model, 'config'):
        model.config.save_pretrained(output_dir)  # type: ignore[attr-defined]
# Save additional files
if additional_files:
for filename, content in additional_files.items():
filepath = os.path.join(output_dir, filename)
with open(filepath, 'w') as f:
if isinstance(content, dict):
json.dump(content, f, indent=2)
else:
f.write(str(content))
# Create README
readme_content = f"""# {model_name} Model
This directory contains a fine-tuned model ready for inference.
## Files:
- `{model_name}.pt`: Model weights
- `tokenizer_config.json`: Tokenizer configuration
- `special_tokens_map.json`: Special tokens mapping
- `vocab.txt` or `tokenizer.json`: Vocabulary file
## Loading the model:
```python
import torch
from transformers import AutoTokenizer
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('{output_dir}')
# Load model
model = YourModelClass() # Initialize your model class
model.load_state_dict(torch.load('{model_name}.pt', map_location='cpu'))
model.eval()
```
## Model Information:
- Model type: {model.__class__.__name__}
- Saved on: {datetime.now().isoformat()}
"""
readme_path = os.path.join(output_dir, 'README.md')
with open(readme_path, 'w') as f:
f.write(readme_content)
logger.info(f"Model saved for inference in {output_dir}")