From bace84f9e9dddffc4f1b987a004abdfc7488f757 Mon Sep 17 00:00:00 2001 From: ldy Date: Wed, 23 Jul 2025 15:17:14 +0800 Subject: [PATCH] Revised saving --- tests/test_installation.py | 7 ++++++- utils/checkpoint.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_installation.py b/tests/test_installation.py index db0b226..1e638ac 100644 --- a/tests/test_installation.py +++ b/tests/test_installation.py @@ -525,7 +525,12 @@ class ComprehensiveTestSuite: # Test saving state_dict = { - 'model_state_dict': torch.randn(10, 10), + 'model_state_dict': { + 'layer1.weight': torch.randn(10, 5), + 'layer1.bias': torch.randn(10), + 'layer2.weight': torch.randn(3, 10), + 'layer2.bias': torch.randn(3) + }, 'optimizer_state_dict': {'lr': 0.001}, 'step': 100 } diff --git a/utils/checkpoint.py b/utils/checkpoint.py index d38f1f0..f49f231 100644 --- a/utils/checkpoint.py +++ b/utils/checkpoint.py @@ -86,7 +86,7 @@ class CheckpointManager: model_path = os.path.join(temp_dir, 'pytorch_model.bin') # Add extra safety check for model state dict model_state = state_dict['model_state_dict'] - if not model_state: + if model_state is None or (isinstance(model_state, dict) and len(model_state) == 0): raise RuntimeError("Model state dict is empty") self._safe_torch_save(model_state, model_path)