From ae0f405a9bd0c9afb088834577804f6991129d77 Mon Sep 17 00:00:00 2001
From: ldy
Date: Tue, 22 Jul 2025 17:49:16 +0800
Subject: [PATCH] Revised init commit

---
 test_checkpoint_fix.py      | 105 ++++++++++++++++++++++++++++++++++++
 tests/test_full_training.py |  17 +++---
 training/trainer.py         |  17 ++++--
 utils/checkpoint.py         |  98 ++++++++++++++++++++++++++++++---
 4 files changed, 219 insertions(+), 18 deletions(-)
 create mode 100644 test_checkpoint_fix.py

diff --git a/test_checkpoint_fix.py b/test_checkpoint_fix.py
new file mode 100644
index 0000000..32b7f1e
--- /dev/null
+++ b/test_checkpoint_fix.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+"""Minimal test to verify checkpoint saving fixes"""
+
+import torch
+import tempfile
+import os
+import shutil
+from pathlib import Path
+import sys
+
+# Add parent directory to path
+sys.path.insert(0, str(Path(__file__).parent))
+
+from utils.checkpoint import CheckpointManager
+
+
+def test_checkpoint_save():
+    """Test basic checkpoint save functionality"""
+    print("🧪 Testing checkpoint save fixes...")
+
+    # Create temporary directory
+    temp_dir = tempfile.mkdtemp()
+    print(f"Using temp directory: {temp_dir}")
+
+    try:
+        # Create checkpoint manager
+        manager = CheckpointManager(
+            output_dir=temp_dir,
+            save_total_limit=2
+        )
+
+        # Create a simple model state dict
+        model_state = {
+            'layer1.weight': torch.randn(10, 5),
+            'layer1.bias': torch.randn(10),
+            'layer2.weight': torch.randn(3, 10),
+            'layer2.bias': torch.randn(3)
+        }
+
+        # Create optimizer state dict
+        optimizer_state = {
+            'param_groups': [{'lr': 0.001, 'momentum': 0.9}],
+            'state': {}
+        }
+
+        # Test save at step 1 (should work)
+        print("Testing checkpoint save at step 1...")
+        state_dict = {
+            'model_state_dict': model_state,
+            'optimizer_state_dict': optimizer_state,
+            'global_step': 1,
+            'current_epoch': 0
+        }
+
+        checkpoint_path = manager.save(state_dict, step=1)
+        print(f"✅ Successfully saved checkpoint at step 1: {checkpoint_path}")
+
+        # Test save at step 0 (should be avoided by trainer logic)
+        print("Testing checkpoint save at step 0...")
+        state_dict['global_step'] = 0
+
+        checkpoint_path = manager.save(state_dict, step=0)
+        print(f"✅ Successfully saved checkpoint at step 0: {checkpoint_path}")
+        print("Note: Step 0 save worked, but trainer logic should prevent this")
+
+        # Verify files exist
+        model_file = os.path.join(checkpoint_path, 'pytorch_model.bin')
+        if os.path.exists(model_file):
+            size = os.path.getsize(model_file) / 1024  # KB
+            print(f"✅ Model file exists and has size: {size:.1f}KB")
+        else:
+            print("❌ Model file not found")
+            return False
+
+        # Test loading
+        print("Testing checkpoint load...")
+        loaded_state = manager.load_checkpoint(checkpoint_path)
+
+        if 'model_state_dict' in loaded_state:
+            print("✅ Successfully loaded checkpoint")
+            return True
+        else:
+            print("❌ Failed to load model state")
+            return False
+
+    except Exception as e:
+        print(f"❌ Test failed: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+    finally:
+        # Cleanup
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
+            print("🧹 Cleaned up temporary files")
+
+
+if __name__ == "__main__":
+    success = test_checkpoint_save()
+    if success:
+        print("🎉 All tests passed! Checkpoint fixes are working.")
+        sys.exit(0)
+    else:
+        print("💥 Tests failed!")
+        sys.exit(1)
\ No newline at end of file
diff --git a/tests/test_full_training.py b/tests/test_full_training.py
index faf9474..c07c07e 100644
--- a/tests/test_full_training.py
+++ b/tests/test_full_training.py
@@ -56,19 +56,22 @@ def main():
         model_name_or_path="./models/bge-m3",
         train_data_path=train_file,
         output_dir="output/test-lr-scheduler",
-        num_train_epochs=2,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=1,  # Small for testing
-        learning_rate=1e-4,  # Higher for testing
+        num_train_epochs=3,
+        per_device_train_batch_size=1,  # Smaller to ensure more steps
+        gradient_accumulation_steps=2,  # More accumulation steps for stability
+        learning_rate=1e-5,  # Lower learning rate for stability
         logging_steps=1,
-        save_steps=10,
-        fp16=False,
+        save_steps=0,  # Disable frequent checkpointing during test
+        eval_steps=0,  # Disable evaluation during test
+        fp16=False,  # Disable mixed precision for stability
+        bf16=False,  # Disable bfloat16
         use_amp=False,
         dataloader_num_workers=0,
         overwrite_output_dir=True,
         use_self_distill=False,
         use_hard_negatives=False,
-        temperature=0.1  # Fixed temperature
+        temperature=0.1,  # Fixed temperature
+        save_total_limit=1  # Limit checkpoint storage
     )
 
     try:
diff --git a/training/trainer.py b/training/trainer.py
index fb5af84..6326843 100644
--- a/training/trainer.py
+++ b/training/trainer.py
@@ -184,7 +184,11 @@ class BaseTrainer(ABC):
             # Add to training history
             self.training_history.append(epoch_metrics)
 
-            if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
+            # Save checkpoint at end of epoch (if save_steps allows and we've done training)
+            if (is_main_process() and
+                self.config.save_steps > 0 and
+                self.global_step > 0 and
+                (epoch + 1) % self.config.save_steps == 0):
                 self._save_checkpoint(metrics=train_metrics)
 
             epoch_time = time.time() - epoch_start_time
@@ -193,7 +197,9 @@ class BaseTrainer(ABC):
             logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")
 
         if is_main_process():
-            self._save_checkpoint(metrics=train_metrics)
+            # Only save final checkpoint if we've actually trained (global_step > 0)
+            if self.global_step > 0:
+                self._save_checkpoint(metrics=train_metrics)
             self._save_training_summary()
             # Close TensorBoard logger
             self.tb_logger.close()
@@ -351,8 +357,11 @@ class BaseTrainer(ABC):
             except Exception as e:
                 logger.debug(f"Logging failed: {e}")
 
-            # Checkpoint saving with error handling
-            if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
+            # Checkpoint saving with error handling - avoid saving at step 0
+            if (self.config.save_steps > 0 and
+                self.global_step > 0 and  # Don't save at step 0
+                self.global_step % self.config.save_steps == 0 and
+                is_main_process()):
                 try:
                     if self.checkpoint_manager:
                         state_dict = {
diff --git a/utils/checkpoint.py b/utils/checkpoint.py
index 5b31c5e..d38f1f0 100644
--- a/utils/checkpoint.py
+++ b/utils/checkpoint.py
@@ -84,7 +84,12 @@ class CheckpointManager:
         # Save model with error handling
         if 'model_state_dict' in state_dict:
             model_path = os.path.join(temp_dir, 'pytorch_model.bin')
-            self._safe_torch_save(state_dict['model_state_dict'], model_path)
+            # Add extra safety check for model state dict
+            model_state = state_dict['model_state_dict']
+            if not model_state:
+                raise RuntimeError("Model state dict is empty")
+
+            self._safe_torch_save(model_state, model_path)
             checkpoint_files['model'] = 'pytorch_model.bin'
 
         # Save optimizer
@@ -356,16 +361,95 @@ class CheckpointManager:
         return os.path.exists(checkpoint_dir)
 
     def _safe_torch_save(self, obj: Any, path: str):
-        """Save torch object with error handling"""
+        """Save torch object with robust error handling and fallback methods"""
         temp_path = f"{path}.tmp"
+
+        # Check disk space before attempting save (require at least 1GB free)
         try:
-            torch.save(obj, temp_path)
-            # Atomic rename
-            os.replace(temp_path, path)
+            import shutil
+            free_space = shutil.disk_usage(os.path.dirname(path)).free
+            if free_space < 1024 * 1024 * 1024:  # 1GB
+                raise RuntimeError(f"Insufficient disk space: {free_space / (1024**3):.1f}GB available")
         except Exception as e:
-            if os.path.exists(temp_path):
+            logger.warning(f"Could not check disk space: {e}")
+
+        # Try multiple save methods with increasing robustness
+        save_methods = [
+            # Method 1: Legacy (non-zipfile) serialization
+            lambda: torch.save(obj, temp_path, _use_new_zipfile_serialization=False),
+            # Method 2: Default PyTorch save
+            lambda: torch.save(obj, temp_path),
+            # Method 3: Save with pickle protocol 4 (more robust)
+            lambda: torch.save(obj, temp_path, pickle_protocol=4),
+            # Method 4: Force CPU and retry (in case of GPU memory issues)
+            lambda: self._cpu_fallback_save(obj, temp_path)
+        ]
+
+        last_error = None
+        for i, save_method in enumerate(save_methods):
+            try:
+                # Clear GPU cache before saving to free up memory
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+                save_method()
+
+                # Verify file was created and has reasonable size
+                if os.path.exists(temp_path):
+                    size = os.path.getsize(temp_path)
+                    if size < 100:  # File too small, likely corrupted
+                        raise RuntimeError(f"Saved file too small: {size} bytes")
+
+                    # Atomic rename
+                    os.replace(temp_path, path)
+                    logger.info(f"Successfully saved {path} using method {i+1} (size: {size/1024/1024:.1f}MB)")
+                    return
+                else:
+                    raise RuntimeError("Temp file was not created")
+
+            except Exception as e:
+                last_error = e
+                logger.warning(f"Save method {i+1} failed for {path}: {e}")
+                if os.path.exists(temp_path):
+                    try:
+                        os.remove(temp_path)
+                    except:
+                        pass
+                continue
+
+        # All methods failed
+        if os.path.exists(temp_path):
+            try:
                 os.remove(temp_path)
-            raise RuntimeError(f"Failed to save {path}: {e}")
+            except:
+                pass
+        raise RuntimeError(f"Failed to save {path} after trying all methods: {last_error}")
+
+    def _cpu_fallback_save(self, obj: Any, path: str):
+        """Fallback save method: move tensors to CPU and save"""
+        import copy
+
+        # Deep copy and move to CPU
+        if hasattr(obj, 'state_dict'):
+            # For models with state_dict method
+            cpu_obj = {k: v.cpu() if isinstance(v, torch.Tensor) else v
+                       for k, v in obj.state_dict().items()}
+        elif isinstance(obj, dict):
+            # For state dict dictionaries
+            cpu_obj = {}
+            for k, v in obj.items():
+                if isinstance(v, torch.Tensor):
+                    cpu_obj[k] = v.cpu()
+                elif isinstance(v, dict):
+                    cpu_obj[k] = {k2: v2.cpu() if isinstance(v2, torch.Tensor) else v2
+                                  for k2, v2 in v.items()}
+                else:
+                    cpu_obj[k] = v
+        else:
+            # For other objects
+            cpu_obj = obj
+
+        torch.save(cpu_obj, path, pickle_protocol=4)
 
     def _safe_json_save(self, obj: Any, path: str):
         """Save JSON with atomic write"""
-- 
2.47.2