Revised initial commit

This commit is contained in:
ldy
2025-07-22 17:49:16 +08:00
parent 36003b83e2
commit ae0f405a9b
4 changed files with 219 additions and 18 deletions

View File

@@ -84,7 +84,12 @@ class CheckpointManager:
# Save model with error handling
if 'model_state_dict' in state_dict:
model_path = os.path.join(temp_dir, 'pytorch_model.bin')
self._safe_torch_save(state_dict['model_state_dict'], model_path)
# Add extra safety check for model state dict
model_state = state_dict['model_state_dict']
if not model_state:
raise RuntimeError("Model state dict is empty")
self._safe_torch_save(model_state, model_path)
checkpoint_files['model'] = 'pytorch_model.bin'
# Save optimizer
@@ -356,16 +361,95 @@ class CheckpointManager:
return os.path.exists(checkpoint_dir)
def _safe_torch_save(self, obj: Any, path: str):
    """Save a torch object atomically, trying several serialization fallbacks.

    Writes to ``<path>.tmp`` first and only renames into place after the file
    passes a basic size sanity check, so a failed save never leaves a
    truncated checkpoint at ``path``.

    Args:
        obj: Any torch-serializable object (state dict, tensor, dict, ...).
        path: Destination file path.

    Raises:
        RuntimeError: If every save method fails; carries the last error.
    """
    temp_path = f"{path}.tmp"

    # Best-effort free-space check (require ~1GB). A failure to *check* is
    # only a warning — the save attempts below still proceed.
    try:
        import shutil
        # `or "."` handles relative paths whose dirname is the empty string.
        free_space = shutil.disk_usage(os.path.dirname(path) or ".").free
        if free_space < 1024 * 1024 * 1024:  # 1GB
            raise RuntimeError(f"Insufficient disk space: {free_space / (1024**3):.1f}GB available")
    except Exception as e:
        logger.warning(f"Could not check disk space: {e}")

    # Ordered fallbacks; each is only tried if the previous one failed.
    save_methods = [
        # Method 1: legacy (non-zipfile) serialization format
        lambda: torch.save(obj, temp_path, _use_new_zipfile_serialization=False),
        # Method 2: default PyTorch zipfile serialization
        lambda: torch.save(obj, temp_path),
        # Method 3: pickle protocol 4 (more robust for large objects)
        lambda: torch.save(obj, temp_path, pickle_protocol=4),
        # Method 4: move all tensors to CPU first (GPU-memory issues)
        lambda: self._cpu_fallback_save(obj, temp_path),
    ]

    last_error = None
    for i, save_method in enumerate(save_methods):
        try:
            # Free cached GPU memory so serialization has headroom.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            save_method()
            # Verify the temp file exists and is not obviously truncated.
            if not os.path.exists(temp_path):
                raise RuntimeError("Temp file was not created")
            size = os.path.getsize(temp_path)
            if size < 100:  # implausibly small -> likely corrupted
                raise RuntimeError(f"Saved file too small: {size} bytes")
            # Atomic rename
            os.replace(temp_path, path)
            logger.info(f"Successfully saved {path} using method {i+1} (size: {size/1024/1024:.1f}MB)")
            return
        except Exception as e:
            last_error = e
            logger.warning(f"Save method {i+1} failed for {path}: {e}")
            # Remove the partial temp file before trying the next method.
            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass

    # All methods failed: clean up any leftover temp file and surface the error.
    if os.path.exists(temp_path):
        try:
            os.remove(temp_path)
        except OSError:
            pass
    raise RuntimeError(f"Failed to save {path} after trying all methods: {last_error}")
def _cpu_fallback_save(self, obj: Any, path: str):
"""Fallback save method: move tensors to CPU and save"""
import copy
# Deep copy and move to CPU
if hasattr(obj, 'state_dict'):
# For models with state_dict method
cpu_obj = {k: v.cpu() if isinstance(v, torch.Tensor) else v
for k, v in obj.state_dict().items()}
elif isinstance(obj, dict):
# For state dict dictionaries
cpu_obj = {}
for k, v in obj.items():
if isinstance(v, torch.Tensor):
cpu_obj[k] = v.cpu()
elif isinstance(v, dict):
cpu_obj[k] = {k2: v2.cpu() if isinstance(v2, torch.Tensor) else v2
for k2, v2 in v.items()}
else:
cpu_obj[k] = v
else:
# For other objects
cpu_obj = obj
torch.save(cpu_obj, path, pickle_protocol=4)
def _safe_json_save(self, obj: Any, path: str):
"""Save JSON with atomic write"""