bge_finetune/test_checkpoint_fix.py
2025-07-22 17:49:16 +08:00

105 lines
3.2 KiB
Python

#!/usr/bin/env python
"""Minimal test to verify checkpoint saving fixes"""
import torch
import tempfile
import os
import shutil
from pathlib import Path
import sys
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent))
from utils.checkpoint import CheckpointManager
def test_checkpoint_save():
    """Smoke-test CheckpointManager's save/load round-trip.

    Saves a tiny synthetic model + optimizer state at step 1 and step 0,
    verifies the weights file landed on disk, reloads the checkpoint, and
    reports success. The scratch directory is always removed, pass or fail.

    Returns:
        bool: True when every stage succeeded, False otherwise.
    """
    print("🧪 Testing checkpoint save fixes...")
    # Scratch directory for checkpoints; removed in the finally block.
    temp_dir = tempfile.mkdtemp()
    print(f"Using temp directory: {temp_dir}")
    try:
        # Manager under test, capped at two retained checkpoints.
        ckpt_mgr = CheckpointManager(
            output_dir=temp_dir,
            save_total_limit=2,
        )
        # Tiny stand-in weights — shapes/content are irrelevant to the test.
        fake_weights = {
            'layer1.weight': torch.randn(10, 5),
            'layer1.bias': torch.randn(10),
            'layer2.weight': torch.randn(3, 10),
            'layer2.bias': torch.randn(3),
        }
        # Minimal optimizer payload mirroring torch's state-dict layout.
        fake_optim = {
            'param_groups': [{'lr': 0.001, 'momentum': 0.9}],
            'state': {},
        }
        # Stage 1: a normal save at step 1 must succeed.
        print("Testing checkpoint save at step 1...")
        payload = {
            'model_state_dict': fake_weights,
            'optimizer_state_dict': fake_optim,
            'global_step': 1,
            'current_epoch': 0,
        }
        checkpoint_path = ckpt_mgr.save(payload, step=1)
        print(f"✅ Successfully saved checkpoint at step 1: {checkpoint_path}")
        # Stage 2: step 0 saves are allowed here; the trainer is what guards it.
        print("Testing checkpoint save at step 0...")
        payload['global_step'] = 0
        checkpoint_path = ckpt_mgr.save(payload, step=0)
        print(f"✅ Successfully saved checkpoint at step 0: {checkpoint_path}")
        print("Note: Step 0 save worked, but trainer logic should prevent this")
        # Guard: the serialized weights file must exist on disk.
        model_file = os.path.join(checkpoint_path, 'pytorch_model.bin')
        if not os.path.exists(model_file):
            print("❌ Model file not found")
            return False
        size = os.path.getsize(model_file) / 1024  # KB
        print(f"✅ Model file exists and has size: {size:.1f}KB")
        # Stage 3: the round-trip must hand the model weights back.
        print("Testing checkpoint load...")
        loaded_state = ckpt_mgr.load_checkpoint(checkpoint_path)
        if 'model_state_dict' not in loaded_state:
            print("❌ Failed to load model state")
            return False
        print("✅ Successfully loaded checkpoint")
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a standalone smoke test and any
        # failure should be reported, not propagated.
        print(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Always drop the scratch directory, whatever happened above.
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print("🧹 Cleaned up temporary files")
if __name__ == "__main__":
    # Run the smoke test and translate its boolean result into an exit code.
    if test_checkpoint_save():
        print("🎉 All tests passed! Checkpoint fixes are working.")
        sys.exit(0)
    print("💥 Tests failed!")
    sys.exit(1)