Revised initial commit
This commit is contained in:
@@ -84,7 +84,12 @@ class CheckpointManager:
|
||||
# Save model with error handling
|
||||
if 'model_state_dict' in state_dict:
|
||||
model_path = os.path.join(temp_dir, 'pytorch_model.bin')
|
||||
self._safe_torch_save(state_dict['model_state_dict'], model_path)
|
||||
# Add extra safety check for model state dict
|
||||
model_state = state_dict['model_state_dict']
|
||||
if not model_state:
|
||||
raise RuntimeError("Model state dict is empty")
|
||||
|
||||
self._safe_torch_save(model_state, model_path)
|
||||
checkpoint_files['model'] = 'pytorch_model.bin'
|
||||
|
||||
# Save optimizer
|
||||
@@ -356,16 +361,95 @@ class CheckpointManager:
|
||||
return os.path.exists(checkpoint_dir)
|
||||
|
||||
def _safe_torch_save(self, obj: Any, path: str):
    """Save a torch object with robust error handling and fallback methods.

    Serializes to ``<path>.tmp`` first and atomically renames into place
    (``os.replace``), so a failed save never leaves a truncated file at
    *path*. Several increasingly defensive save strategies are tried in
    order; the first one that produces a plausible file wins.

    Args:
        obj: Object to serialize with ``torch.save``.
        path: Final destination file path.

    Raises:
        RuntimeError: If the destination filesystem reports less than 1GB
            free, or if every save method fails.
    """
    temp_path = f"{path}.tmp"

    # Check disk space before attempting the save (require at least 1GB
    # free). Failure to *stat* the filesystem is logged but non-fatal;
    # genuinely insufficient space is a hard error and must not be
    # swallowed by the stat-failure handler.
    free_space = None
    try:
        import shutil
        free_space = shutil.disk_usage(os.path.dirname(path) or ".").free
    except OSError as e:
        logger.warning(f"Could not check disk space: {e}")
    if free_space is not None and free_space < 1024 * 1024 * 1024:  # 1GB
        raise RuntimeError(f"Insufficient disk space: {free_space / (1024**3):.1f}GB available")

    # Try multiple save methods with increasing robustness.
    save_methods = [
        # Method 1: legacy (non-zipfile) serialization format.
        lambda: torch.save(obj, temp_path, _use_new_zipfile_serialization=False),
        # Method 2: standard PyTorch save with default settings.
        lambda: torch.save(obj, temp_path),
        # Method 3: save with pickle protocol 4 (supports large objects).
        lambda: torch.save(obj, temp_path, pickle_protocol=4),
        # Method 4: force tensors to CPU and retry (GPU memory issues).
        lambda: self._cpu_fallback_save(obj, temp_path),
    ]

    last_error = None
    for i, save_method in enumerate(save_methods):
        try:
            # Clear GPU cache before saving to free up memory.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            save_method()

            # Verify the file was created and has a plausible size.
            if not os.path.exists(temp_path):
                raise RuntimeError("Temp file was not created")
            size = os.path.getsize(temp_path)
            if size < 100:  # File too small, likely corrupted
                raise RuntimeError(f"Saved file too small: {size} bytes")

            # Atomic rename: *path* is either the old file or the complete
            # new file, never a partial write.
            os.replace(temp_path, path)
            logger.info(f"Successfully saved {path} using method {i+1} (size: {size/1024/1024:.1f}MB)")
            return
        except Exception as e:
            last_error = e
            logger.warning(f"Save method {i+1} failed for {path}: {e}")
            # Remove any partial temp file before trying the next method.
            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass

    # All methods failed; make sure no stale temp file is left behind.
    if os.path.exists(temp_path):
        try:
            os.remove(temp_path)
        except OSError:
            pass
    raise RuntimeError(
        f"Failed to save {path} after trying all methods: {last_error}"
    ) from last_error
|
||||
|
||||
def _cpu_fallback_save(self, obj: Any, path: str):
|
||||
"""Fallback save method: move tensors to CPU and save"""
|
||||
import copy
|
||||
|
||||
# Deep copy and move to CPU
|
||||
if hasattr(obj, 'state_dict'):
|
||||
# For models with state_dict method
|
||||
cpu_obj = {k: v.cpu() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in obj.state_dict().items()}
|
||||
elif isinstance(obj, dict):
|
||||
# For state dict dictionaries
|
||||
cpu_obj = {}
|
||||
for k, v in obj.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
cpu_obj[k] = v.cpu()
|
||||
elif isinstance(v, dict):
|
||||
cpu_obj[k] = {k2: v2.cpu() if isinstance(v2, torch.Tensor) else v2
|
||||
for k2, v2 in v.items()}
|
||||
else:
|
||||
cpu_obj[k] = v
|
||||
else:
|
||||
# For other objects
|
||||
cpu_obj = obj
|
||||
|
||||
torch.save(cpu_obj, path, pickle_protocol=4)
|
||||
|
||||
def _safe_json_save(self, obj: Any, path: str):
|
||||
"""Save JSON with atomic write"""
|
||||
|
||||
Reference in New Issue
Block a user