#!/usr/bin/env python
"""Minimal test to verify checkpoint saving fixes.

Run this file directly as a script: it saves two small checkpoints through
``CheckpointManager`` into a temporary directory, checks that the model file
was written, loads the checkpoint back, and exits with status 0 on success
or 1 on failure.
"""
import os
import shutil
import sys
import tempfile
import traceback
from pathlib import Path

import torch

# Make this script's own directory importable so that `utils.checkpoint`
# resolves no matter which working directory the script is launched from.
sys.path.insert(0, str(Path(__file__).parent))

from utils.checkpoint import CheckpointManager  # noqa: E402


def _run_save_and_load(temp_dir: str) -> bool:
    """Save checkpoints at steps 1 and 0 under *temp_dir*, then reload one.

    Args:
        temp_dir: Existing directory handed to ``CheckpointManager`` as its
            ``output_dir``.

    Returns:
        True when the model file of the last saved checkpoint exists and the
        reloaded state contains ``'model_state_dict'``; False otherwise.
    """
    manager = CheckpointManager(
        output_dir=temp_dir,
        save_total_limit=2,
    )

    # A tiny two-layer "model" is enough to exercise serialization.
    model_state = {
        'layer1.weight': torch.randn(10, 5),
        'layer1.bias': torch.randn(10),
        'layer2.weight': torch.randn(3, 10),
        'layer2.bias': torch.randn(3),
    }

    optimizer_state = {
        'param_groups': [{'lr': 0.001, 'momentum': 0.9}],
        'state': {},
    }

    # Test save at step 1 (should work)
    print("Testing checkpoint save at step 1...")
    state_dict = {
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer_state,
        'global_step': 1,
        'current_epoch': 0,
    }
    checkpoint_path = manager.save(state_dict, step=1)
    print(f"✅ Successfully saved checkpoint at step 1: {checkpoint_path}")

    # Test save at step 0 (should be avoided by trainer logic). The same
    # dict is reused; only the recorded step changes.
    print("Testing checkpoint save at step 0...")
    state_dict['global_step'] = 0
    checkpoint_path = manager.save(state_dict, step=0)
    print(f"✅ Successfully saved checkpoint at step 0: {checkpoint_path}")
    print("Note: Step 0 save worked, but trainer logic should prevent this")

    # Only the most recent checkpoint (step 0) is verified on disk; the path
    # returned by save() is treated as a directory holding the model file.
    model_file = os.path.join(checkpoint_path, 'pytorch_model.bin')
    if not os.path.exists(model_file):
        print("❌ Model file not found")
        return False
    size = os.path.getsize(model_file) / 1024  # KB
    print(f"✅ Model file exists and has size: {size:.1f}KB")

    print("Testing checkpoint load...")
    loaded_state = manager.load_checkpoint(checkpoint_path)
    if 'model_state_dict' not in loaded_state:
        print("❌ Failed to load model state")
        return False

    print("✅ Successfully loaded checkpoint")
    return True


def test_checkpoint_save() -> bool:
    """Test basic checkpoint save functionality.

    Creates a temporary directory, runs the save/verify/load sequence in it,
    and always removes the directory afterwards.

    Returns:
        True if every step succeeded, False if a check failed or any
        exception was raised (the traceback is printed, not propagated).

    Note:
        Failures are reported through the return value rather than by
        raising, so the result is only acted on when this file is run as a
        script (see the ``__main__`` guard below).
    """
    print("🧪 Testing checkpoint save fixes...")

    temp_dir = tempfile.mkdtemp()
    print(f"Using temp directory: {temp_dir}")

    try:
        return _run_save_and_load(temp_dir)
    except Exception as e:
        # Top-level boundary of the script: report the failure with its
        # traceback and convert it into a False result.
        print(f"❌ Test failed: {e}")
        traceback.print_exc()
        return False
    finally:
        # mkdtemp() leaves deletion to the caller.
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print("🧹 Cleaned up temporary files")


if __name__ == "__main__":
    success = test_checkpoint_save()
    if success:
        print("🎉 All tests passed! Checkpoint fixes are working.")
    else:
        print("💥 Tests failed!")
    sys.exit(0 if success else 1)