Revised init commit
This commit is contained in:
105
test_checkpoint_fix.py
Normal file
105
test_checkpoint_fix.py
Normal file
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python
"""Minimal test to verify checkpoint saving fixes"""

import os
import shutil
import sys
import tempfile
from pathlib import Path

import torch

# Make the directory containing this script importable so the local
# `utils` package resolves when the file is run directly.
sys.path.insert(0, str(Path(__file__).parent))

from utils.checkpoint import CheckpointManager
def test_checkpoint_save():
    """Exercise CheckpointManager save/load round-trips in a throwaway directory.

    Saves a checkpoint at step 1 and step 0, verifies the model file lands on
    disk, reloads it, and reports progress to stdout. Returns True on success,
    False on any failure; the temp directory is removed unconditionally.
    """
    print("🧪 Testing checkpoint save fixes...")

    # Scratch directory that the finally-block below always removes.
    scratch_dir = tempfile.mkdtemp()
    print(f"Using temp directory: {scratch_dir}")

    try:
        ckpt_manager = CheckpointManager(output_dir=scratch_dir, save_total_limit=2)

        # Tiny fake two-layer model weights — enough to produce a real file.
        fake_weights = {
            'layer1.weight': torch.randn(10, 5),
            'layer1.bias': torch.randn(10),
            'layer2.weight': torch.randn(3, 10),
            'layer2.bias': torch.randn(3),
        }
        fake_optim = {
            'param_groups': [{'lr': 0.001, 'momentum': 0.9}],
            'state': {},
        }

        # Save at step 1 — the normal case, expected to succeed.
        print("Testing checkpoint save at step 1...")
        payload = {
            'model_state_dict': fake_weights,
            'optimizer_state_dict': fake_optim,
            'global_step': 1,
            'current_epoch': 0,
        }
        ckpt_path = ckpt_manager.save(payload, step=1)
        print(f"✅ Successfully saved checkpoint at step 1: {ckpt_path}")

        # Save at step 0 — the manager allows it; the trainer is what should
        # normally prevent a step-0 save.
        print("Testing checkpoint save at step 0...")
        payload['global_step'] = 0
        ckpt_path = ckpt_manager.save(payload, step=0)
        print(f"✅ Successfully saved checkpoint at step 0: {ckpt_path}")
        print("Note: Step 0 save worked, but trainer logic should prevent this")

        # Guard: the weights file must actually exist on disk.
        weights_file = os.path.join(ckpt_path, 'pytorch_model.bin')
        if not os.path.exists(weights_file):
            print("❌ Model file not found")
            return False
        kb = os.path.getsize(weights_file) / 1024  # KB
        print(f"✅ Model file exists and has size: {kb:.1f}KB")

        # Round-trip: reload the checkpoint we just wrote.
        print("Testing checkpoint load...")
        reloaded = ckpt_manager.load_checkpoint(ckpt_path)
        if 'model_state_dict' not in reloaded:
            print("❌ Failed to load model state")
            return False
        print("✅ Successfully loaded checkpoint")
        return True

    except Exception as e:
        # Top-level test boundary: report, dump the traceback, and fail.
        print(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Cleanup runs on every exit path.
        if os.path.exists(scratch_dir):
            shutil.rmtree(scratch_dir)
            print("🧹 Cleaned up temporary files")
if __name__ == "__main__":
    # Exit 0 on success, 1 on failure so CI can gate on this script.
    if test_checkpoint_save():
        print("🎉 All tests passed! Checkpoint fixes are working.")
        sys.exit(0)
    print("💥 Tests failed!")
    sys.exit(1)
Reference in New Issue
Block a user