Compare commits

...

2 Commits

Author  SHA1        Message                                                                 Date
ldy     6dbd2f3281  Merge pull request 'Revised init commit' (#1) from fix-main into main  2025-07-22 09:53:02 +00:00
                    Reviewed-on: #1
ldy     ae0f405a9b  Revised init commit                                                     2025-07-22 17:49:16 +08:00
4 changed files with 219 additions and 18 deletions

test_checkpoint_fix.py (new file, +105 lines)

@@ -0,0 +1,105 @@
#!/usr/bin/env python
"""Minimal test to verify checkpoint saving fixes"""
import torch
import tempfile
import os
import shutil
from pathlib import Path
import sys

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent))

from utils.checkpoint import CheckpointManager


def test_checkpoint_save():
    """Test basic checkpoint save functionality"""
    print("🧪 Testing checkpoint save fixes...")

    # Create temporary directory
    temp_dir = tempfile.mkdtemp()
    print(f"Using temp directory: {temp_dir}")

    try:
        # Create checkpoint manager
        manager = CheckpointManager(
            output_dir=temp_dir,
            save_total_limit=2
        )

        # Create a simple model state dict
        model_state = {
            'layer1.weight': torch.randn(10, 5),
            'layer1.bias': torch.randn(10),
            'layer2.weight': torch.randn(3, 10),
            'layer2.bias': torch.randn(3)
        }

        # Create optimizer state dict
        optimizer_state = {
            'param_groups': [{'lr': 0.001, 'momentum': 0.9}],
            'state': {}
        }

        # Test save at step 1 (should work)
        print("Testing checkpoint save at step 1...")
        state_dict = {
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer_state,
            'global_step': 1,
            'current_epoch': 0
        }
        checkpoint_path = manager.save(state_dict, step=1)
        print(f"✅ Successfully saved checkpoint at step 1: {checkpoint_path}")

        # Test save at step 0 (should be avoided by trainer logic)
        print("Testing checkpoint save at step 0...")
        state_dict['global_step'] = 0
        checkpoint_path = manager.save(state_dict, step=0)
        print(f"✅ Successfully saved checkpoint at step 0: {checkpoint_path}")
        print("Note: Step 0 save worked, but trainer logic should prevent this")

        # Verify files exist
        model_file = os.path.join(checkpoint_path, 'pytorch_model.bin')
        if os.path.exists(model_file):
            size = os.path.getsize(model_file) / 1024  # KB
            print(f"✅ Model file exists and has size: {size:.1f}KB")
        else:
            print("❌ Model file not found")
            return False

        # Test loading
        print("Testing checkpoint load...")
        loaded_state = manager.load_checkpoint(checkpoint_path)
        if 'model_state_dict' in loaded_state:
            print("✅ Successfully loaded checkpoint")
            return True
        else:
            print("❌ Failed to load model state")
            return False

    except Exception as e:
        print(f"❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Cleanup
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
            print("🧹 Cleaned up temporary files")


if __name__ == "__main__":
    success = test_checkpoint_save()
    if success:
        print("🎉 All tests passed! Checkpoint fixes are working.")
        sys.exit(0)
    else:
        print("💥 Tests failed!")
        sys.exit(1)

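For comparison, the same round-trip check can be expressed as a pytest test rather than a standalone script. This is only a sketch: it assumes the `CheckpointManager.save(state_dict, step=...)` and `load_checkpoint(path)` API used in `test_checkpoint_fix.py` above, and the file and test names are hypothetical.

# test_checkpoint_pytest.py -- hypothetical pytest variant of the script above.
# Assumes CheckpointManager exposes save(state_dict, step=...) and
# load_checkpoint(path) exactly as used in test_checkpoint_fix.py.
import torch

from utils.checkpoint import CheckpointManager


def test_save_and_load_roundtrip(tmp_path):
    manager = CheckpointManager(output_dir=str(tmp_path), save_total_limit=2)

    state_dict = {
        'model_state_dict': {'w': torch.randn(4, 4)},
        'optimizer_state_dict': {'param_groups': [{'lr': 1e-3}], 'state': {}},
        'global_step': 1,
        'current_epoch': 0,
    }

    # Saving at step 1 should produce a loadable checkpoint directory.
    path = manager.save(state_dict, step=1)
    loaded = manager.load_checkpoint(path)
    assert 'model_state_dict' in loaded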

@@ -56,19 +56,22 @@ def main():
         model_name_or_path="./models/bge-m3",
         train_data_path=train_file,
         output_dir="output/test-lr-scheduler",
-        num_train_epochs=2,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=1,  # Small for testing
-        learning_rate=1e-4,  # Higher for testing
+        num_train_epochs=3,
+        per_device_train_batch_size=1,  # Smaller to ensure more steps
+        gradient_accumulation_steps=2,  # More accumulation steps for stability
+        learning_rate=1e-5,  # Lower learning rate for stability
         logging_steps=1,
-        save_steps=10,
-        fp16=False,
+        save_steps=0,  # Disable frequent checkpointing during test
+        eval_steps=0,  # Disable evaluation during test
+        fp16=False,  # Disable mixed precision for stability
+        bf16=False,  # Disable bfloat16
         use_amp=False,
         dataloader_num_workers=0,
         overwrite_output_dir=True,
         use_self_distill=False,
         use_hard_negatives=False,
-        temperature=0.1  # Fixed temperature
+        temperature=0.1,  # Fixed temperature
+        save_total_limit=1  # Limit checkpoint storage
     )
 
     try:

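The revised settings trade throughput for more optimizer steps and smaller, more stable updates. A minimal sketch of the arithmetic, with an illustrative dataset size (the real test data comes from `train_file`, whose size is not shown here):

# Rough arithmetic behind the revised test settings (illustrative numbers only;
# the real dataset size and world size come from the test fixture, not shown here).
num_examples = 32                 # hypothetical size of the tiny test set
per_device_train_batch_size = 1   # smaller batches -> more micro-batches per epoch
gradient_accumulation_steps = 2   # two micro-batches per optimizer step
world_size = 1                    # single-GPU test run

micro_batches_per_epoch = num_examples // (per_device_train_batch_size * world_size)
optimizer_steps_per_epoch = micro_batches_per_epoch // gradient_accumulation_steps

print(micro_batches_per_epoch, optimizer_steps_per_epoch)  # 32 16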

@@ -184,7 +184,11 @@ class BaseTrainer(ABC):
             # Add to training history
             self.training_history.append(epoch_metrics)
 
-            if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
+            # Save checkpoint at end of epoch (if save_steps allows and we've done training)
+            if (is_main_process() and
+                self.config.save_steps > 0 and
+                self.global_step > 0 and
+                (epoch + 1) % self.config.save_steps == 0):
                 self._save_checkpoint(metrics=train_metrics)
 
             epoch_time = time.time() - epoch_start_time
@@ -193,6 +197,8 @@ class BaseTrainer(ABC):
             logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")
 
         if is_main_process():
-            self._save_checkpoint(metrics=train_metrics)
+            # Only save final checkpoint if we've actually trained (global_step > 0)
+            if self.global_step > 0:
+                self._save_checkpoint(metrics=train_metrics)
             self._save_training_summary()
             # Close TensorBoard logger
@@ -351,8 +357,11 @@ class BaseTrainer(ABC):
                 except Exception as e:
                     logger.debug(f"Logging failed: {e}")
 
-            # Checkpoint saving with error handling
-            if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
+            # Checkpoint saving with error handling - avoid saving at step 0
+            if (self.config.save_steps > 0 and
+                self.global_step > 0 and  # Don't save at step 0
+                self.global_step % self.config.save_steps == 0 and
+                is_main_process()):
                 try:
                     if self.checkpoint_manager:
                         state_dict = {

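The guard added above (main process only, `save_steps > 0`, `global_step > 0`, step divisible by `save_steps`) can be isolated for clarity. A minimal sketch, assuming nothing beyond the condition itself; the helper name is hypothetical and not part of `BaseTrainer`:

# Sketch of the save guard introduced above, pulled out as a standalone helper
# so the conditions are easy to unit-test. The argument names mirror the trainer
# fields, but this helper itself is hypothetical.
def should_save_checkpoint(global_step: int, save_steps: int, is_main: bool) -> bool:
    """True only on the main process, after at least one optimizer step,
    and only when periodic saving is enabled and the step is a multiple of it."""
    return (
        is_main
        and save_steps > 0          # save_steps == 0 disables periodic saving
        and global_step > 0         # never save an untrained step-0 checkpoint
        and global_step % save_steps == 0
    )


assert should_save_checkpoint(0, 10, True) is False   # step 0 is skipped
assert should_save_checkpoint(10, 10, True) is True
assert should_save_checkpoint(10, 0, True) is False   # disabled via save_steps=0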

@@ -84,7 +84,12 @@ class CheckpointManager:
             # Save model with error handling
             if 'model_state_dict' in state_dict:
                 model_path = os.path.join(temp_dir, 'pytorch_model.bin')
-                self._safe_torch_save(state_dict['model_state_dict'], model_path)
+                # Add extra safety check for model state dict
+                model_state = state_dict['model_state_dict']
+                if not model_state:
+                    raise RuntimeError("Model state dict is empty")
+
+                self._safe_torch_save(model_state, model_path)
                 checkpoint_files['model'] = 'pytorch_model.bin'
 
             # Save optimizer
@@ -356,16 +361,95 @@ class CheckpointManager:
         return os.path.exists(checkpoint_dir)
 
     def _safe_torch_save(self, obj: Any, path: str):
-        """Save torch object with error handling"""
+        """Save torch object with robust error handling and fallback methods"""
         temp_path = f"{path}.tmp"
-        try:
-            torch.save(obj, temp_path)
-            # Atomic rename
-            os.replace(temp_path, path)
-        except Exception as e:
-            if os.path.exists(temp_path):
-                os.remove(temp_path)
-            raise RuntimeError(f"Failed to save {path}: {e}")
+
+        # Check disk space before attempting save (require at least 1GB free)
+        try:
+            import shutil
+            free_space = shutil.disk_usage(os.path.dirname(path)).free
+            if free_space < 1024 * 1024 * 1024:  # 1GB
+                raise RuntimeError(f"Insufficient disk space: {free_space / (1024**3):.1f}GB available")
+        except Exception as e:
+            logger.warning(f"Could not check disk space: {e}")
+
+        # Try multiple save methods with increasing robustness
+        save_methods = [
+            # Method 1: Legacy (non-zipfile) PyTorch serialization
+            lambda: torch.save(obj, temp_path, _use_new_zipfile_serialization=False),
+            # Method 2: Default PyTorch save (new zipfile format)
+            lambda: torch.save(obj, temp_path),
+            # Method 3: Save with pickle protocol 4 (more robust)
+            lambda: torch.save(obj, temp_path, pickle_protocol=4),
+            # Method 4: Force CPU and retry (in case of GPU memory issues)
+            lambda: self._cpu_fallback_save(obj, temp_path)
+        ]
+
+        last_error = None
+        for i, save_method in enumerate(save_methods):
+            try:
+                # Clear GPU cache before saving to free up memory
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+                save_method()
+
+                # Verify file was created and has reasonable size
+                if os.path.exists(temp_path):
+                    size = os.path.getsize(temp_path)
+                    if size < 100:  # File too small, likely corrupted
+                        raise RuntimeError(f"Saved file too small: {size} bytes")
+
+                    # Atomic rename
+                    os.replace(temp_path, path)
+                    logger.info(f"Successfully saved {path} using method {i+1} (size: {size/1024/1024:.1f}MB)")
+                    return
+                else:
+                    raise RuntimeError("Temp file was not created")
+
+            except Exception as e:
+                last_error = e
+                logger.warning(f"Save method {i+1} failed for {path}: {e}")
+                if os.path.exists(temp_path):
+                    try:
+                        os.remove(temp_path)
+                    except:
+                        pass
+                continue
+
+        # All methods failed
+        if os.path.exists(temp_path):
+            try:
+                os.remove(temp_path)
+            except:
+                pass
+        raise RuntimeError(f"Failed to save {path} after trying all methods: {last_error}")
+
+    def _cpu_fallback_save(self, obj: Any, path: str):
+        """Fallback save method: move tensors to CPU and save"""
+        # Move tensors to CPU before saving
+        if hasattr(obj, 'state_dict'):
+            # For models with state_dict method
+            cpu_obj = {k: v.cpu() if isinstance(v, torch.Tensor) else v
+                       for k, v in obj.state_dict().items()}
+        elif isinstance(obj, dict):
+            # For state dict dictionaries
+            cpu_obj = {}
+            for k, v in obj.items():
+                if isinstance(v, torch.Tensor):
+                    cpu_obj[k] = v.cpu()
+                elif isinstance(v, dict):
+                    cpu_obj[k] = {k2: v2.cpu() if isinstance(v2, torch.Tensor) else v2
+                                  for k2, v2 in v.items()}
+                else:
+                    cpu_obj[k] = v
+        else:
+            # For other objects
+            cpu_obj = obj
+
+        torch.save(cpu_obj, path, pickle_protocol=4)
 
     def _safe_json_save(self, obj: Any, path: str):
         """Save JSON with atomic write"""