Compare commits
No commits in common. "6dbd2f3281070453fd3c4178b9d33d4851f1d83e" and "36003b83e2d8f5fd115bf82a2e805dd86c5079a1" have entirely different histories.
6dbd2f3281...36003b83e2
@@ -1,105 +0,0 @@
-#!/usr/bin/env python
-"""Minimal test to verify checkpoint saving fixes"""
-
-import torch
-import tempfile
-import os
-import shutil
-from pathlib import Path
-import sys
-
-# Add parent directory to path
-sys.path.insert(0, str(Path(__file__).parent))
-
-from utils.checkpoint import CheckpointManager
-
-
-def test_checkpoint_save():
-    """Test basic checkpoint save functionality"""
-    print("🧪 Testing checkpoint save fixes...")
-
-    # Create temporary directory
-    temp_dir = tempfile.mkdtemp()
-    print(f"Using temp directory: {temp_dir}")
-
-    try:
-        # Create checkpoint manager
-        manager = CheckpointManager(
-            output_dir=temp_dir,
-            save_total_limit=2
-        )
-
-        # Create a simple model state dict
-        model_state = {
-            'layer1.weight': torch.randn(10, 5),
-            'layer1.bias': torch.randn(10),
-            'layer2.weight': torch.randn(3, 10),
-            'layer2.bias': torch.randn(3)
-        }
-
-        # Create optimizer state dict
-        optimizer_state = {
-            'param_groups': [{'lr': 0.001, 'momentum': 0.9}],
-            'state': {}
-        }
-
-        # Test save at step 1 (should work)
-        print("Testing checkpoint save at step 1...")
-        state_dict = {
-            'model_state_dict': model_state,
-            'optimizer_state_dict': optimizer_state,
-            'global_step': 1,
-            'current_epoch': 0
-        }
-
-        checkpoint_path = manager.save(state_dict, step=1)
-        print(f"✅ Successfully saved checkpoint at step 1: {checkpoint_path}")
-
-        # Test save at step 0 (should be avoided by trainer logic)
-        print("Testing checkpoint save at step 0...")
-        state_dict['global_step'] = 0
-
-        checkpoint_path = manager.save(state_dict, step=0)
-        print(f"✅ Successfully saved checkpoint at step 0: {checkpoint_path}")
-        print("Note: Step 0 save worked, but trainer logic should prevent this")
-
-        # Verify files exist
-        model_file = os.path.join(checkpoint_path, 'pytorch_model.bin')
-        if os.path.exists(model_file):
-            size = os.path.getsize(model_file) / 1024  # KB
-            print(f"✅ Model file exists and has size: {size:.1f}KB")
-        else:
-            print("❌ Model file not found")
-            return False
-
-        # Test loading
-        print("Testing checkpoint load...")
-        loaded_state = manager.load_checkpoint(checkpoint_path)
-
-        if 'model_state_dict' in loaded_state:
-            print("✅ Successfully loaded checkpoint")
-            return True
-        else:
-            print("❌ Failed to load model state")
-            return False
-
-    except Exception as e:
-        print(f"❌ Test failed: {e}")
-        import traceback
-        traceback.print_exc()
-        return False
-    finally:
-        # Cleanup
-        if os.path.exists(temp_dir):
-            shutil.rmtree(temp_dir)
-            print("🧹 Cleaned up temporary files")
-
-
-if __name__ == "__main__":
-    success = test_checkpoint_save()
-    if success:
-        print("🎉 All tests passed! Checkpoint fixes are working.")
-        sys.exit(0)
-    else:
-        print("💥 Tests failed!")
-        sys.exit(1)
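The removed test only touches the public surface of CheckpointManager: a constructor taking output_dir and save_total_limit, save(state_dict, step=...) returning a checkpoint directory that contains pytorch_model.bin, and load_checkpoint(path) returning the stored dict. For orientation, here is a minimal sketch of a manager with that interface; the class name, the checkpoint-<step> directory naming, the training_state.pt file, and the pruning logic are assumptions for illustration, not the implementation in utils/checkpoint.py.

import os
import shutil
import torch


class MinimalCheckpointManager:
    """Hypothetical stand-in for utils.checkpoint.CheckpointManager.

    Sketches only the behaviour the deleted test relies on: save() returns a
    directory containing pytorch_model.bin, load_checkpoint() returns the saved
    dict, and save_total_limit prunes the oldest checkpoints.
    """

    def __init__(self, output_dir: str, save_total_limit: int = 2):
        self.output_dir = output_dir
        self.save_total_limit = save_total_limit
        os.makedirs(output_dir, exist_ok=True)

    def save(self, state_dict: dict, step: int) -> str:
        ckpt_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        os.makedirs(ckpt_dir, exist_ok=True)
        # Model weights go to pytorch_model.bin, everything else to a separate file
        # (training_state.pt is an assumed name, not the project's layout).
        torch.save(state_dict.get("model_state_dict", {}),
                   os.path.join(ckpt_dir, "pytorch_model.bin"))
        torch.save({k: v for k, v in state_dict.items() if k != "model_state_dict"},
                   os.path.join(ckpt_dir, "training_state.pt"))
        self._prune()
        return ckpt_dir

    def load_checkpoint(self, ckpt_dir: str) -> dict:
        state = torch.load(os.path.join(ckpt_dir, "training_state.pt"),
                           map_location="cpu")
        state["model_state_dict"] = torch.load(
            os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
        return state

    def _prune(self) -> None:
        # Keep only the newest save_total_limit checkpoint directories.
        ckpts = sorted(
            (d for d in os.listdir(self.output_dir) if d.startswith("checkpoint-")),
            key=lambda d: int(d.split("-")[1]),
        )
        for old in ckpts[:-self.save_total_limit]:
            shutil.rmtree(os.path.join(self.output_dir, old))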
@@ -56,22 +56,19 @@ def main():
         model_name_or_path="./models/bge-m3",
         train_data_path=train_file,
         output_dir="output/test-lr-scheduler",
-        num_train_epochs=3,
-        per_device_train_batch_size=1,  # Smaller to ensure more steps
-        gradient_accumulation_steps=2,  # More accumulation steps for stability
-        learning_rate=1e-5,  # Lower learning rate for stability
+        num_train_epochs=2,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=1,  # Small for testing
+        learning_rate=1e-4,  # Higher for testing
         logging_steps=1,
-        save_steps=0,  # Disable frequent checkpointing during test
-        eval_steps=0,  # Disable evaluation during test
-        fp16=False,  # Disable mixed precision for stability
-        bf16=False,  # Disable bfloat16
+        save_steps=10,
+        fp16=False,
         use_amp=False,
         dataloader_num_workers=0,
         overwrite_output_dir=True,
         use_self_distill=False,
         use_hard_negatives=False,
-        temperature=0.1,  # Fixed temperature
-        save_total_limit=1  # Limit checkpoint storage
+        temperature=0.1  # Fixed temperature
     )

     try:
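The config change above trades the old slow-but-stable settings (batch size 1, accumulation 2, lr 1e-5, 3 epochs, saving disabled) for a quicker test run (batch size 2, accumulation 1, lr 1e-4, 2 epochs) and re-enables periodic saving with save_steps=10. Whether save_steps ever fires depends on how many optimizer steps the run produces; a rough helper for reasoning about that (the 100-sample dataset below is hypothetical, not from the repo):

def optimizer_steps(num_samples: int, per_device_batch: int,
                    grad_accum: int, epochs: int, world_size: int = 1) -> int:
    """Approximate optimizer steps for a run; useful when picking save_steps."""
    steps_per_epoch = num_samples // (per_device_batch * grad_accum * world_size)
    return steps_per_epoch * epochs

# With a hypothetical 100-sample toy set:
print(optimizer_steps(100, 1, 2, 3))  # old config -> 150 optimizer steps
print(optimizer_steps(100, 2, 1, 2))  # new config -> 100 optimizer steps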
@@ -184,11 +184,7 @@ class BaseTrainer(ABC):
             # Add to training history
             self.training_history.append(epoch_metrics)

-            # Save checkpoint at end of epoch (if save_steps allows and we've done training)
-            if (is_main_process() and
-                    self.config.save_steps > 0 and
-                    self.global_step > 0 and
-                    (epoch + 1) % self.config.save_steps == 0):
+            if is_main_process() and (epoch + 1) % self.config.save_steps == 0:
                 self._save_checkpoint(metrics=train_metrics)

             epoch_time = time.time() - epoch_start_time
@@ -197,8 +193,6 @@ class BaseTrainer(ABC):
             logger.info(f"Epoch {epoch + 1} metrics: {train_metrics}")

         if is_main_process():
-            # Only save final checkpoint if we've actually trained (global_step > 0)
-            if self.global_step > 0:
-                self._save_checkpoint(metrics=train_metrics)
+            self._save_checkpoint(metrics=train_metrics)
             self._save_training_summary()
             # Close TensorBoard logger
@@ -357,11 +351,8 @@ class BaseTrainer(ABC):
         except Exception as e:
             logger.debug(f"Logging failed: {e}")

-        # Checkpoint saving with error handling - avoid saving at step 0
-        if (self.config.save_steps > 0 and
-                self.global_step > 0 and  # Don't save at step 0
-                self.global_step % self.config.save_steps == 0 and
-                is_main_process()):
+        # Checkpoint saving with error handling
+        if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0 and is_main_process():
             try:
                 if self.checkpoint_manager:
                     state_dict = {
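The three BaseTrainer hunks above all tighten the same rule on the removed (-) side: only the main process saves, never at global_step 0, and never when save_steps is disabled. A minimal sketch of that rule as a standalone predicate (the helper name is ours; the trainer inlines the condition):

def should_save_checkpoint(global_step: int, save_steps: int,
                           is_main: bool, interval: int = None) -> bool:
    """True only when a checkpoint save is actually warranted.

    - save_steps <= 0 means checkpointing is disabled
    - global_step == 0 means no training has happened yet
    - interval is the counter being tested (global step, or epoch + 1 for
      epoch-end saves); defaults to global_step itself.
    """
    if not is_main or save_steps <= 0 or global_step <= 0:
        return False
    counter = global_step if interval is None else interval
    return counter % save_steps == 0


# Step-based saving (the hunk at old line 357):
assert should_save_checkpoint(global_step=100, save_steps=10, is_main=True)
assert not should_save_checkpoint(global_step=0, save_steps=10, is_main=True)
# Epoch-end saving (the hunk at old line 184): test epoch + 1 against save_steps.
assert should_save_checkpoint(global_step=50, save_steps=1, is_main=True, interval=3)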
@@ -84,12 +84,7 @@ class CheckpointManager:
         # Save model with error handling
         if 'model_state_dict' in state_dict:
             model_path = os.path.join(temp_dir, 'pytorch_model.bin')
-            # Add extra safety check for model state dict
-            model_state = state_dict['model_state_dict']
-            if not model_state:
-                raise RuntimeError("Model state dict is empty")
-
-            self._safe_torch_save(model_state, model_path)
+            self._safe_torch_save(state_dict['model_state_dict'], model_path)
             checkpoint_files['model'] = 'pytorch_model.bin'

         # Save optimizer
@@ -361,95 +356,16 @@ class CheckpointManager:
         return os.path.exists(checkpoint_dir)

     def _safe_torch_save(self, obj: Any, path: str):
-        """Save torch object with robust error handling and fallback methods"""
+        """Save torch object with error handling"""
         temp_path = f"{path}.tmp"
-
-        # Check disk space before attempting save (require at least 1GB free)
-        try:
-            import shutil
-            free_space = shutil.disk_usage(os.path.dirname(path)).free
-            if free_space < 1024 * 1024 * 1024:  # 1GB
-                raise RuntimeError(f"Insufficient disk space: {free_space / (1024**3):.1f}GB available")
-        except Exception as e:
-            logger.warning(f"Could not check disk space: {e}")
-
-        # Try multiple save methods with increasing robustness
-        save_methods = [
-            # Method 1: Standard PyTorch save with weights_only for security
-            lambda: torch.save(obj, temp_path, _use_new_zipfile_serialization=False),
-            # Method 2: Legacy PyTorch save without new zipfile
-            lambda: torch.save(obj, temp_path),
-            # Method 3: Save with pickle protocol 4 (more robust)
-            lambda: torch.save(obj, temp_path, pickle_protocol=4),
-            # Method 4: Force CPU and retry (in case of GPU memory issues)
-            lambda: self._cpu_fallback_save(obj, temp_path)
-        ]
-
-        last_error = None
-        for i, save_method in enumerate(save_methods):
-            try:
-                # Clear GPU cache before saving to free up memory
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-
-                save_method()
-
-                # Verify file was created and has reasonable size
-                if os.path.exists(temp_path):
-                    size = os.path.getsize(temp_path)
-                    if size < 100:  # File too small, likely corrupted
-                        raise RuntimeError(f"Saved file too small: {size} bytes")
-
-                    # Atomic rename
-                    os.replace(temp_path, path)
-                    logger.info(f"Successfully saved {path} using method {i+1} (size: {size/1024/1024:.1f}MB)")
-                    return
-                else:
-                    raise RuntimeError("Temp file was not created")
-
-            except Exception as e:
-                last_error = e
-                logger.warning(f"Save method {i+1} failed for {path}: {e}")
-                if os.path.exists(temp_path):
-                    try:
-                        os.remove(temp_path)
-                    except:
-                        pass
-                continue
-
-        # All methods failed
-        if os.path.exists(temp_path):
-            try:
-                os.remove(temp_path)
-            except:
-                pass
-        raise RuntimeError(f"Failed to save {path} after trying all methods: {last_error}")
-
-    def _cpu_fallback_save(self, obj: Any, path: str):
-        """Fallback save method: move tensors to CPU and save"""
-        import copy
-
-        # Deep copy and move to CPU
-        if hasattr(obj, 'state_dict'):
-            # For models with state_dict method
-            cpu_obj = {k: v.cpu() if isinstance(v, torch.Tensor) else v
-                       for k, v in obj.state_dict().items()}
-        elif isinstance(obj, dict):
-            # For state dict dictionaries
-            cpu_obj = {}
-            for k, v in obj.items():
-                if isinstance(v, torch.Tensor):
-                    cpu_obj[k] = v.cpu()
-                elif isinstance(v, dict):
-                    cpu_obj[k] = {k2: v2.cpu() if isinstance(v2, torch.Tensor) else v2
-                                  for k2, v2 in v.items()}
-                else:
-                    cpu_obj[k] = v
-        else:
-            # For other objects
-            cpu_obj = obj
-
-        torch.save(cpu_obj, path, pickle_protocol=4)
+        try:
+            torch.save(obj, temp_path)
+            # Atomic rename
+            os.replace(temp_path, path)
+        except Exception as e:
+            if os.path.exists(temp_path):
+                os.remove(temp_path)
+            raise RuntimeError(f"Failed to save {path}: {e}")

     def _safe_json_save(self, obj: Any, path: str):
         """Save JSON with atomic write"""
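The removed (-) side of the hunk above is the defensive write path: save to a temporary file, sanity-check its size, promote it with an atomic os.replace, and fall back to alternative serialization (including a CPU copy) if an attempt fails. A condensed, self-contained sketch of the same pattern as a free function rather than a method; the function name and the 1 GB free-space threshold are illustrative, not part of the project's API:

import contextlib
import logging
import os
import shutil
import torch

logger = logging.getLogger(__name__)


def atomic_torch_save(obj, path: str, min_free_bytes: int = 1024 ** 3) -> None:
    """Write obj to path via a temp file and os.replace, with simple fallbacks."""
    temp_path = f"{path}.tmp"

    # Best-effort disk space check; log a warning but never abort because of it.
    try:
        free = shutil.disk_usage(os.path.dirname(path) or ".").free
        if free < min_free_bytes:
            logger.warning("Low disk space for %s: %.1fGB free", path, free / 1024 ** 3)
    except OSError as e:
        logger.warning("Could not check disk space: %s", e)

    def cpu_copy(o):
        # Recursively move tensors in (possibly nested) dicts to CPU;
        # other values pass through unchanged.
        if isinstance(o, dict):
            return {k: cpu_copy(v) for k, v in o.items()}
        return o.cpu() if isinstance(o, torch.Tensor) else o

    attempts = [
        lambda: torch.save(obj, temp_path),
        lambda: torch.save(obj, temp_path, pickle_protocol=4),
        lambda: torch.save(cpu_copy(obj), temp_path, pickle_protocol=4),
    ]

    last_error = None
    for i, attempt in enumerate(attempts, start=1):
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # free GPU memory before serializing
            attempt()
            size = os.path.getsize(temp_path)
            if size < 100:  # suspiciously small file, likely corrupt
                raise RuntimeError(f"Saved file too small: {size} bytes")
            os.replace(temp_path, path)  # atomic on the same filesystem
            logger.info("Saved %s with attempt %d (%.1f MB)", path, i, size / 1024 ** 2)
            return
        except Exception as e:  # broad on purpose: try the next method
            last_error = e
            logger.warning("Save attempt %d for %s failed: %s", i, path, e)
            if os.path.exists(temp_path):
                with contextlib.suppress(OSError):
                    os.remove(temp_path)

    raise RuntimeError(f"Failed to save {path} after all attempts: {last_error}")

Note that os.replace is only atomic when the temporary file and the final path sit on the same filesystem, which is why the temporary name is derived from the destination path itself.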