#!/usr/bin/env python
|
|
"""Minimal test to verify checkpoint saving fixes"""
|
|
|
|
import torch
|
|
import tempfile
|
|
import os
|
|
import shutil
|
|
from pathlib import Path
|
|
import sys
|
|
|
|
# Add parent directory to path
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
from utils.checkpoint import CheckpointManager
|
|
|
|
|
|
def test_checkpoint_save():
    """Exercise CheckpointManager's save/load round-trip in a temp directory.

    Saves a small synthetic model/optimizer state at step 1 and step 0,
    verifies the serialized model file exists on disk, then loads the
    checkpoint back and checks the model state is present.

    Returns:
        bool: True when save, file-existence check, and load all succeed.
    """
    print("🧪 Testing checkpoint save fixes...")

    # TemporaryDirectory guarantees cleanup even if the process exits
    # unexpectedly, replacing the manual mkdtemp + shutil.rmtree bookkeeping.
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"Using temp directory: {temp_dir}")

        try:
            manager = CheckpointManager(
                output_dir=temp_dir,
                save_total_limit=2
            )

            # Minimal synthetic two-layer model weights.
            model_state = {
                'layer1.weight': torch.randn(10, 5),
                'layer1.bias': torch.randn(10),
                'layer2.weight': torch.randn(3, 10),
                'layer2.bias': torch.randn(3)
            }

            # Skeleton optimizer state; an empty 'state' dict is a valid
            # freshly-constructed optimizer state.
            optimizer_state = {
                'param_groups': [{'lr': 0.001, 'momentum': 0.9}],
                'state': {}
            }

            # Save at step 1 — the normal, expected path.
            print("Testing checkpoint save at step 1...")
            state_dict = {
                'model_state_dict': model_state,
                'optimizer_state_dict': optimizer_state,
                'global_step': 1,
                'current_epoch': 0
            }

            checkpoint_path = manager.save(state_dict, step=1)
            print(f"✅ Successfully saved checkpoint at step 1: {checkpoint_path}")

            # Save at step 0 — the trainer is supposed to prevent this
            # upstream; here we only confirm the manager itself tolerates it.
            print("Testing checkpoint save at step 0...")
            state_dict['global_step'] = 0

            checkpoint_path = manager.save(state_dict, step=0)
            print(f"✅ Successfully saved checkpoint at step 0: {checkpoint_path}")
            print("Note: Step 0 save worked, but trainer logic should prevent this")

            # Verify the serialized weights actually landed on disk.
            model_file = os.path.join(checkpoint_path, 'pytorch_model.bin')
            if os.path.exists(model_file):
                size = os.path.getsize(model_file) / 1024  # KB
                print(f"✅ Model file exists and has size: {size:.1f}KB")
            else:
                print("❌ Model file not found")
                return False

            # Round-trip: load the checkpoint we just wrote.
            print("Testing checkpoint load...")
            loaded_state = manager.load_checkpoint(checkpoint_path)

            if 'model_state_dict' in loaded_state:
                print("✅ Successfully loaded checkpoint")
                return True

            print("❌ Failed to load model state")
            return False

        except Exception as e:
            # Broad catch is deliberate: this is a standalone test harness
            # reporting pass/fail, not library code.
            print(f"❌ Test failed: {e}")
            import traceback
            traceback.print_exc()
            return False
        finally:
            # Actual directory removal happens at `with` exit, immediately
            # after this message.
            print("🧹 Cleaned up temporary files")
|
|
|
|
|
|
if __name__ == "__main__":
    # Guard-clause form: exit 0 on success, otherwise report and exit 1.
    if test_checkpoint_save():
        print("🎉 All tests passed! Checkpoint fixes are working.")
        sys.exit(0)
    print("💥 Tests failed!")
    sys.exit(1)