init commit
This commit is contained in:
127
utils/__init__.py
Normal file
127
utils/__init__.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Utility modules for logging, checkpointing, and distributed training
|
||||
"""
|
||||
from .logging import (
|
||||
setup_logging,
|
||||
get_logger,
|
||||
MetricsLogger,
|
||||
TensorBoardLogger,
|
||||
ProgressLogger,
|
||||
log_model_info,
|
||||
log_training_summary
|
||||
)
|
||||
from .checkpoint import CheckpointManager, save_model_for_inference
|
||||
from .distributed import (
|
||||
setup_distributed,
|
||||
cleanup_distributed,
|
||||
is_main_process,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
get_local_rank,
|
||||
synchronize,
|
||||
all_reduce,
|
||||
all_gather,
|
||||
broadcast,
|
||||
reduce_dict,
|
||||
setup_model_for_distributed,
|
||||
create_distributed_dataloader,
|
||||
set_random_seed,
|
||||
print_distributed_info,
|
||||
DistributedMetricAggregator,
|
||||
save_on_master,
|
||||
log_on_master
|
||||
)
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def optional_import(module_name, feature_desc=None, warn_only=True):
    """Attempt to import an optional module, logging when it is missing.

    Args:
        module_name: Dotted module name to import.
        feature_desc: Optional description of the feature that needs it,
            appended to the log message.
        warn_only: Log at WARNING level when True, ERROR otherwise.

    Returns:
        The imported module, or None when the import fails.
    """
    import importlib

    try:
        module = importlib.import_module(module_name)
    except ImportError:
        msg = f"Optional dependency '{module_name}' not found."
        if feature_desc:
            msg += f" {feature_desc} will be unavailable."
        # Pick the log level once instead of branching twice.
        log = logger.warning if warn_only else logger.error
        log(msg)
        return None
    return module
|
||||
|
||||
# Import guard for optional dependencies
|
||||
|
||||
def _check_optional_dependencies():
    """Check for optional dependencies and log warnings when a feature that
    is likely to be used is missing its dependency.

    Best-effort: this function never raises; failures only reduce the set
    of checks performed.
    """
    # Check for torch_npu if running on Ascend.
    # Fix: the previous probe called ``torch.device()`` with no arguments
    # (via getattr fallback), which raises TypeError on every torch install.
    # Detect NPU intent from the DEVICE env var, or an already-registered
    # torch.npu namespace, instead.
    torch = optional_import('torch', warn_only=True)
    if torch and (os.environ.get('DEVICE', '').lower() == 'npu' or hasattr(torch, 'npu')):
        optional_import('torch_npu', feature_desc='Ascend NPU support')

    # Load the project config once; both the wandb and tensorboard checks
    # below use it. Fix: previously ``config`` could be unbound in the
    # tensorboard section when loading failed before the assignment.
    config = None
    try:
        from utils.config_loader import get_config
        config = get_config()
    except Exception:
        pass  # config loader unavailable; skip config-driven checks

    if config is not None:
        # Check for wandb if Weights & Biases logging is enabled
        use_wandb = False
        try:
            use_wandb = config.get('logging.use_wandb', False)  # type: ignore[arg-type]
        except Exception:
            pass
        if use_wandb:
            optional_import('wandb', feature_desc='Weights & Biases logging')

        # Check for tensorboard if tensorboard logging is likely to be used
        tensorboard_dir = None
        try:
            tensorboard_dir = config.get('logging.tensorboard_dir', None)  # type: ignore[arg-type]
        except Exception:
            pass
        if tensorboard_dir:
            optional_import('tensorboard', feature_desc='TensorBoard logging')

    # Check for pytest if running tests (sys is imported at module level;
    # the redundant local ``import sys`` was removed)
    if any('pytest' in arg for arg in sys.argv):
        optional_import('pytest', feature_desc='Some tests may not run')
|
||||
|
||||
|
||||
# Public API of the utils package, grouped by submodule.
__all__ = [
    # Logging
    'setup_logging',
    'get_logger',
    'MetricsLogger',
    'TensorBoardLogger',
    'ProgressLogger',
    'log_model_info',
    'log_training_summary',

    # Checkpointing
    'CheckpointManager',
    'save_model_for_inference',

    # Distributed training
    'setup_distributed',
    'cleanup_distributed',
    'is_main_process',
    'get_rank',
    'get_world_size',
    'get_local_rank',
    'synchronize',
    'all_reduce',
    'all_gather',
    'broadcast',
    'reduce_dict',
    'setup_model_for_distributed',
    'create_distributed_dataloader',
    'set_random_seed',
    'print_distributed_info',
    'DistributedMetricAggregator',
    'save_on_master',
    'log_on_master'
]
|
||||
271
utils/ascend_support.py
Normal file
271
utils/ascend_support.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Huawei Ascend NPU support utilities for BGE fine-tuning
|
||||
|
||||
This module provides compatibility layer for running BGE models on Ascend NPUs.
|
||||
Requires torch_npu to be installed: pip install torch_npu
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global flag to track Ascend availability
|
||||
# Cached result of the one-time NPU probe (see check_ascend_availability).
_ASCEND_AVAILABLE = False
# Ensures the probe runs at most once per process, even when it fails.
_ASCEND_CHECK_DONE = False
|
||||
|
||||
|
||||
def check_ascend_availability() -> bool:
    """
    Check if Ascend NPU is available

    The probe runs at most once per process; the result is cached in the
    module-level _ASCEND_AVAILABLE flag and returned on subsequent calls.

    Returns:
        bool: True if Ascend NPU is available and properly configured
    """
    global _ASCEND_AVAILABLE, _ASCEND_CHECK_DONE

    if _ASCEND_CHECK_DONE:
        return _ASCEND_AVAILABLE

    # Mark done up-front so a failing probe is never retried.
    _ASCEND_CHECK_DONE = True

    try:
        # Importing torch_npu registers the 'npu' namespace on torch.
        import torch_npu

        # Check if NPU is available
        if hasattr(torch, 'npu') and torch.npu.is_available():
            _ASCEND_AVAILABLE = True
            npu_count = torch.npu.device_count()
            logger.info(f"Ascend NPU available with {npu_count} devices")

            # Log NPU information
            for i in range(npu_count):
                props = torch.npu.get_device_properties(i)
                logger.info(f"NPU {i}: {props.name}, Total Memory: {props.total_memory / 1024**3:.2f} GB")
        else:
            logger.info("Ascend NPU not available")

    except ImportError:
        logger.debug("torch_npu not installed. Ascend NPU support disabled.")
    except Exception as e:
        # Any unexpected probe failure is treated as "not available".
        logger.warning(f"Error checking Ascend availability: {e}")

    return _ASCEND_AVAILABLE
|
||||
|
||||
|
||||
def get_ascend_backend() -> str:
    """
    Select the distributed-communication backend for the current hardware.

    Returns:
        str: 'hccl' for Ascend, 'nccl' for CUDA, 'gloo' for CPU
    """
    # Guard-clause style: Ascend takes priority, then CUDA, then CPU.
    if check_ascend_availability():
        return 'hccl'
    if torch.cuda.is_available():
        return 'nccl'
    return 'gloo'
|
||||
|
||||
|
||||
def setup_ascend_device(local_rank: int = 0) -> torch.device:
    """
    Bind the current process to an Ascend NPU and return its device handle.

    Args:
        local_rank: Local rank for distributed training (selects the NPU index).

    Returns:
        torch.device: The configured 'npu:<local_rank>' device.

    Raises:
        RuntimeError: If no Ascend NPU is available.
    """
    if not check_ascend_availability():
        raise RuntimeError("Ascend NPU not available")

    try:
        # Set NPU device for this process before handing out the handle.
        torch.npu.set_device(local_rank)
        npu_device = torch.device(f'npu:{local_rank}')
    except Exception as e:
        logger.error(f"Failed to setup Ascend device: {e}")
        raise

    logger.info(f"Using Ascend NPU device: {npu_device}")
    return npu_device
|
||||
|
||||
|
||||
def convert_model_for_ascend(model: torch.nn.Module) -> torch.nn.Module:
    """
    Apply Ascend-specific optimizations to a model when an NPU is present.

    Best-effort: if the NPU is unavailable or any optimization step fails,
    the model is returned as-is (possibly partially converted, matching the
    original semantics).

    Args:
        model: PyTorch model

    Returns:
        Model configured for Ascend
    """
    if not check_ascend_availability():
        return model

    try:
        import torch_npu

        # Rebind `model` so a later failure still returns the transferred
        # model rather than the pre-conversion object.
        model = torch_npu.contrib.transfer_to_npu(model)

        # Graph mode only exists in some torch_npu releases — feature-detect.
        if hasattr(torch_npu, 'enable_graph_mode'):
            torch_npu.enable_graph_mode()
            logger.info("Enabled Ascend graph mode")

        return model
    except Exception as e:
        logger.warning(f"Failed to apply Ascend optimizations: {e}")
        return model
|
||||
|
||||
|
||||
def optimize_for_ascend(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Optimize configuration for Ascend NPU.

    Args:
        config: Training configuration

    Returns:
        Optimized configuration (a new dict when an NPU is available; the
        input is returned untouched otherwise)
    """
    if not check_ascend_availability():
        return config

    # Fix: use a deep copy. The previous shallow .copy() shared the nested
    # 'distributed'/'optimization' dicts with the caller, so the writes
    # below mutated the caller's config despite the copy.
    import copy
    config = copy.deepcopy(config)

    # Ascend-specific optimizations
    ascend_config = config.get('ascend', {})

    # Set backend to hccl for distributed training
    if ascend_config.get('backend') == 'hccl':
        config.setdefault('distributed', {})['backend'] = 'hccl'

    # Enable mixed precision if specified
    if ascend_config.get('mixed_precision', False):
        config.setdefault('optimization', {})['fp16'] = True
        logger.info("Enabled mixed precision for Ascend")

    # Set visible devices
    if 'visible_devices' in ascend_config:
        os.environ['ASCEND_RT_VISIBLE_DEVICES'] = ascend_config['visible_devices']
        logger.info(f"Set ASCEND_RT_VISIBLE_DEVICES={ascend_config['visible_devices']}")

    return config
|
||||
|
||||
|
||||
def get_ascend_memory_info(device_id: int = 0) -> Tuple[float, float]:
    """
    Get Ascend NPU memory information

    Args:
        device_id: NPU device ID

    Returns:
        Tuple of (used_memory_gb, total_memory_gb); (0.0, 0.0) when the NPU
        is unavailable or the query fails.
    """
    if not check_ascend_availability():
        return 0.0, 0.0

    try:
        import torch_npu

        # Get memory info
        # NOTE(review): these calls look like they belong on torch.npu
        # (torch.npu.memory_allocated / torch.npu.get_device_properties);
        # whether torch_npu exports them at top level depends on the
        # installed torch_npu version — confirm. A failure here is caught
        # below and reported as (0.0, 0.0).
        used = torch_npu.memory_allocated(device_id) / 1024**3
        total = torch_npu.get_device_properties(device_id).total_memory / 1024**3

        return used, total
    except Exception as e:
        logger.warning(f"Failed to get Ascend memory info: {e}")
        return 0.0, 0.0
|
||||
|
||||
|
||||
def clear_ascend_cache():
    """Release cached NPU memory back to the device (no-op without an NPU)."""
    if not check_ascend_availability():
        return

    try:
        torch.npu.empty_cache()
    except Exception as e:
        logger.warning(f"Failed to clear Ascend cache: {e}")
    else:
        logger.debug("Cleared Ascend NPU cache")
|
||||
|
||||
|
||||
def is_ascend_available() -> bool:
    """Convenience alias for check_ascend_availability().

    Returns:
        bool: True if Ascend NPU is available
    """
    return check_ascend_availability()
|
||||
|
||||
|
||||
class AscendDataLoader:
    """Iterable wrapper that moves every batch onto an Ascend/NPU device.

    Wraps any iterable of batches (dicts, sequences, or raw tensors) and
    yields the same structures with each tensor transferred to ``device``.
    Note: list and tuple batches are both yielded as lists, matching the
    original wrapper's behavior.
    """

    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def _to_device(self, value):
        """Move a single value to the target device when it is a tensor."""
        if isinstance(value, torch.Tensor):
            return value.to(self.device)
        return value

    def __iter__(self):
        for batch in self.dataloader:
            if isinstance(batch, dict):
                yield {key: self._to_device(val) for key, val in batch.items()}
            elif isinstance(batch, (list, tuple)):
                yield [self._to_device(val) for val in batch]
            else:
                yield batch.to(self.device)

    def __len__(self):
        return len(self.dataloader)
|
||||
|
||||
|
||||
# Compatibility functions
|
||||
def npu_synchronize():
    """Block until queued device work completes (NPU first, else CUDA)."""
    if check_ascend_availability():
        try:
            torch.npu.synchronize()
        except Exception as e:
            logger.warning(f"NPU synchronize failed: {e}")
        return
    if torch.cuda.is_available():
        torch.cuda.synchronize()
|
||||
|
||||
|
||||
def get_device_name(device: torch.device) -> str:
    """
    Get a human-readable device name, with Ascend NPU support.

    Args:
        device: PyTorch device

    Returns:
        str: Device name (e.g. "Ascend <chip>", a CUDA GPU name, or
        ``str(device)`` for other device types)
    """
    if device.type == 'npu':
        if check_ascend_availability():
            try:
                props = torch.npu.get_device_properties(device.index or 0)
                return f"Ascend {props.name}"
            except Exception:
                # Fix: was a bare `except:`, which also swallowed
                # SystemExit/KeyboardInterrupt.
                return "Ascend NPU"
        return "NPU (unavailable)"
    elif device.type == 'cuda':
        return torch.cuda.get_device_name(device.index or 0)
    else:
        return str(device)
|
||||
465
utils/checkpoint.py
Normal file
465
utils/checkpoint.py
Normal file
@@ -0,0 +1,465 @@
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional, List
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import glob
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CheckpointManager:
    """
    Manages model checkpoints with automatic cleanup and best model tracking

    Checkpoints are written to ``<output_dir>/checkpoint-<step>`` using a
    temp-directory + rename scheme so a crash mid-save never corrupts an
    existing checkpoint. At most ``save_total_limit`` step checkpoints are
    kept (oldest pruned first); the best model is mirrored into
    ``best_model_dir``.
    """

    def __init__(
        self,
        output_dir: str,
        save_total_limit: int = 3,
        best_model_dir: str = "best_model",
        save_optimizer: bool = True,
        save_scheduler: bool = True
    ):
        self.output_dir = output_dir
        self.save_total_limit = save_total_limit
        # best_model_dir is resolved relative to output_dir
        self.best_model_dir = os.path.join(output_dir, best_model_dir)
        self.save_optimizer = save_optimizer
        self.save_scheduler = save_scheduler

        # Create directories
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(self.best_model_dir, exist_ok=True)

        # Track checkpoints as (step, path) tuples, sorted oldest-first
        self.checkpoints = []
        self._load_existing_checkpoints()

    def _load_existing_checkpoints(self):
        """Load list of existing checkpoints from output_dir into self.checkpoints"""
        checkpoint_dirs = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))

        # Sort by step number
        self.checkpoints = []
        for cp_dir in checkpoint_dirs:
            try:
                step = int(cp_dir.split("-")[-1])
                self.checkpoints.append((step, cp_dir))
            except:
                # NOTE(review): bare except — also (usefully) skips
                # non-numeric names such as leftover "checkpoint-<n>-tmp"
                # directories; consider narrowing to ValueError.
                pass

        self.checkpoints.sort(key=lambda x: x[0])

    def save(
        self,
        state_dict: Dict[str, Any],
        step: int,
        is_best: bool = False,
        metrics: Optional[Dict[str, float]] = None
    ) -> str:
        """
        Save checkpoint with atomic writes and error handling

        Args:
            state_dict: State dictionary containing model, optimizer, etc.
            step: Current training step
            is_best: Whether this is the best model so far
            metrics: Optional metrics to save with checkpoint

        Returns:
            Path to saved checkpoint

        Raises:
            RuntimeError: If any part of the save fails (a pre-existing
                checkpoint for this step is restored from its backup).
        """
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        temp_dir = os.path.join(self.output_dir, f"checkpoint-{step}-tmp")

        try:
            # Create temporary directory for atomic write
            os.makedirs(temp_dir, exist_ok=True)

            # Save individual components; filenames are recorded so
            # training_state.json describes the checkpoint contents.
            checkpoint_files = {}

            # Save model with error handling
            if 'model_state_dict' in state_dict:
                model_path = os.path.join(temp_dir, 'pytorch_model.bin')
                self._safe_torch_save(state_dict['model_state_dict'], model_path)
                checkpoint_files['model'] = 'pytorch_model.bin'

            # Save optimizer
            if self.save_optimizer and 'optimizer_state_dict' in state_dict:
                optimizer_path = os.path.join(temp_dir, 'optimizer.pt')
                self._safe_torch_save(state_dict['optimizer_state_dict'], optimizer_path)
                checkpoint_files['optimizer'] = 'optimizer.pt'

            # Save scheduler
            if self.save_scheduler and 'scheduler_state_dict' in state_dict:
                scheduler_path = os.path.join(temp_dir, 'scheduler.pt')
                self._safe_torch_save(state_dict['scheduler_state_dict'], scheduler_path)
                checkpoint_files['scheduler'] = 'scheduler.pt'

            # Save training state
            training_state = {
                'step': step,
                'checkpoint_files': checkpoint_files,
                'timestamp': datetime.now().isoformat()
            }

            # Add additional state info
            for key in ['global_step', 'current_epoch', 'best_metric']:
                if key in state_dict:
                    training_state[key] = state_dict[key]

            if metrics:
                training_state['metrics'] = metrics

            # Save config if provided
            if 'config' in state_dict:
                config_path = os.path.join(temp_dir, 'config.json')
                self._safe_json_save(state_dict['config'], config_path)

            # Save training state
            state_path = os.path.join(temp_dir, 'training_state.json')
            self._safe_json_save(training_state, state_path)

            # Atomic rename from temp to final directory
            if os.path.exists(checkpoint_dir):
                # Backup existing checkpoint so it can be restored on failure
                backup_dir = f"{checkpoint_dir}.backup"
                if os.path.exists(backup_dir):
                    shutil.rmtree(backup_dir)
                shutil.move(checkpoint_dir, backup_dir)

            # Move temp directory to final location
            shutil.move(temp_dir, checkpoint_dir)

            # Remove backup if successful
            if os.path.exists(f"{checkpoint_dir}.backup"):
                shutil.rmtree(f"{checkpoint_dir}.backup")

            # Add to checkpoint list
            self.checkpoints.append((step, checkpoint_dir))

            # Save best model
            if is_best:
                self._save_best_model(checkpoint_dir, step, metrics)

            # Cleanup old checkpoints
            self._cleanup_checkpoints()

            logger.info(f"Saved checkpoint to {checkpoint_dir}")
            return checkpoint_dir

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            # Cleanup temp directory
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            # Restore backup if exists
            backup_dir = f"{checkpoint_dir}.backup"
            if os.path.exists(backup_dir):
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                shutil.move(backup_dir, checkpoint_dir)
            raise RuntimeError(f"Checkpoint save failed: {e}")

    def _save_best_model(
        self,
        checkpoint_dir: str,
        step: int,
        metrics: Optional[Dict[str, float]] = None
    ):
        """Mirror a checkpoint's files into best_model_dir and record its provenance"""
        # Copy checkpoint to best model directory (top-level files only)
        for filename in os.listdir(checkpoint_dir):
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(self.best_model_dir, filename)
            if os.path.isfile(src):
                shutil.copy2(src, dst)

        # Save best model info
        best_info = {
            'step': step,
            'source_checkpoint': checkpoint_dir,
            'timestamp': datetime.now().isoformat()
        }
        if metrics:
            best_info['metrics'] = metrics

        best_info_path = os.path.join(self.best_model_dir, 'best_model_info.json')
        with open(best_info_path, 'w') as f:
            json.dump(best_info, f, indent=2)

        logger.info(f"Saved best model from step {step}")

    def _cleanup_checkpoints(self):
        """Remove old checkpoints to maintain save_total_limit"""
        if self.save_total_limit is not None and len(self.checkpoints) > self.save_total_limit:
            # Remove oldest checkpoints (list is kept sorted oldest-first)
            checkpoints_to_remove = self.checkpoints[:-self.save_total_limit]

            for step, checkpoint_dir in checkpoints_to_remove:
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                    logger.info(f"Removed old checkpoint: {checkpoint_dir}")

            # Update checkpoint list
            self.checkpoints = self.checkpoints[-self.save_total_limit:]

    def load_checkpoint(
        self,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False,
        map_location: str = 'cpu',
        strict: bool = True
    ) -> Dict[str, Any]:
        """
        Load checkpoint with robust error handling

        Args:
            checkpoint_path: Path to specific checkpoint
            load_best: Whether to load best model
            map_location: Device to map tensors to
            strict: Whether to strictly enforce state dict matching
                (NOTE(review): currently unused here — state dicts are
                returned to the caller, which applies them.)

        Returns:
            Loaded state dictionary

        Raises:
            RuntimeError: If checkpoint loading fails
        """
        try:
            if load_best:
                checkpoint_path = self.best_model_dir
            elif checkpoint_path is None:
                # Load latest checkpoint
                if self.checkpoints:
                    checkpoint_path = self.checkpoints[-1][1]
                else:
                    raise ValueError("No checkpoints found")

            # Ensure checkpoint_path is a string
            if checkpoint_path is None:
                raise ValueError("Checkpoint path is None")

            checkpoint_path = str(checkpoint_path)  # Ensure it's a string

            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

            logger.info(f"Loading checkpoint from {checkpoint_path}")

            # Load training state with error handling
            state_dict = {}
            state_path = os.path.join(checkpoint_path, 'training_state.json')
            if os.path.exists(state_path):
                try:
                    with open(state_path, 'r') as f:
                        training_state = json.load(f)
                    state_dict.update(training_state)
                except Exception as e:
                    logger.warning(f"Failed to load training state: {e}")
                    training_state = {}
            else:
                training_state = {}

            # Load model state
            model_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
            if os.path.exists(model_path):
                try:
                    state_dict['model_state_dict'] = torch.load(
                        model_path,
                        map_location=map_location,
                        weights_only=True  # Security: prevent arbitrary code execution
                    )
                except Exception as e:
                    # Try legacy loading without weights_only for backward compatibility
                    try:
                        state_dict['model_state_dict'] = torch.load(
                            model_path,
                            map_location=map_location
                        )
                        logger.warning("Loaded checkpoint using legacy mode (weights_only=False)")
                    except Exception as legacy_e:
                        raise RuntimeError(f"Failed to load model state: {e}, Legacy attempt: {legacy_e}")
            else:
                logger.warning(f"Model state not found at {model_path}")

            # Load optimizer state (optional; failure is non-fatal)
            optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
            if os.path.exists(optimizer_path):
                try:
                    state_dict['optimizer_state_dict'] = torch.load(
                        optimizer_path,
                        map_location=map_location,
                        weights_only=True
                    )
                except Exception as e:
                    # Try legacy loading
                    try:
                        state_dict['optimizer_state_dict'] = torch.load(
                            optimizer_path,
                            map_location=map_location
                        )
                    except Exception as legacy_e:
                        # NOTE(review): logs the first error (e), not legacy_e
                        logger.warning(f"Failed to load optimizer state: {e}")

            # Load scheduler state (optional; failure is non-fatal)
            scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
            if os.path.exists(scheduler_path):
                try:
                    state_dict['scheduler_state_dict'] = torch.load(
                        scheduler_path,
                        map_location=map_location,
                        weights_only=True
                    )
                except Exception as e:
                    # Try legacy loading
                    try:
                        state_dict['scheduler_state_dict'] = torch.load(
                            scheduler_path,
                            map_location=map_location
                        )
                    except Exception as legacy_e:
                        # NOTE(review): logs the first error (e), not legacy_e
                        logger.warning(f"Failed to load scheduler state: {e}")

            # Load config if available
            config_path = os.path.join(checkpoint_path, 'config.json')
            if os.path.exists(config_path):
                try:
                    with open(config_path, 'r') as f:
                        state_dict['config'] = json.load(f)
                except Exception as e:
                    logger.warning(f"Failed to load config: {e}")

            logger.info(f"Successfully loaded checkpoint from {checkpoint_path}")
            return state_dict

        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise RuntimeError(f"Checkpoint loading failed: {e}")

    def get_latest_checkpoint(self) -> Optional[str]:
        """Get path to latest checkpoint, or None when there is none"""
        if self.checkpoints:
            return self.checkpoints[-1][1]
        return None

    def get_best_checkpoint(self) -> str:
        """Get path to best model checkpoint"""
        return self.best_model_dir

    def checkpoint_exists(self, step: int) -> bool:
        """Check if checkpoint exists for given step"""
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        return os.path.exists(checkpoint_dir)

    def _safe_torch_save(self, obj: Any, path: str):
        """Save torch object via a temp file + atomic rename; raises RuntimeError on failure"""
        temp_path = f"{path}.tmp"
        try:
            torch.save(obj, temp_path)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")

    def _safe_json_save(self, obj: Any, path: str):
        """Save JSON via a temp file + atomic rename; raises RuntimeError on failure"""
        temp_path = f"{path}.tmp"
        try:
            with open(temp_path, 'w') as f:
                json.dump(obj, f, indent=2)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")
|
||||
|
||||
|
||||
def save_model_for_inference(
    model: torch.nn.Module,
    tokenizer: Any,
    output_dir: str,
    model_name: str = "model",
    save_config: bool = True,
    additional_files: Optional[Dict[str, str]] = None
):
    """
    Save model in a format ready for inference

    Args:
        model: The model to save
        tokenizer: Tokenizer to save (saved only if it has save_pretrained)
        output_dir: Output directory
        model_name: Name for the model weights file (<model_name>.pt)
        save_config: Whether to save model config
        additional_files: Additional files to include; dict values are
            written as JSON, everything else via str()
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save model weights
    model_path = os.path.join(output_dir, f"{model_name}.pt")
    torch.save(model.state_dict(), model_path)
    logger.info(f"Saved model to {model_path}")

    # Save tokenizer
    if hasattr(tokenizer, 'save_pretrained'):
        tokenizer.save_pretrained(output_dir)
        logger.info(f"Saved tokenizer to {output_dir}")

    # Save config if available.
    # Fix: removed the unused `config_path` local — save_pretrained writes
    # config.json into output_dir itself.
    if save_config and hasattr(model, 'config'):
        model.config.save_pretrained(output_dir)  # type: ignore[attr-defined]

    # Save additional files
    if additional_files:
        for filename, content in additional_files.items():
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'w') as f:
                if isinstance(content, dict):
                    json.dump(content, f, indent=2)
                else:
                    f.write(str(content))

    # Create README
    readme_content = f"""# {model_name} Model

This directory contains a fine-tuned model ready for inference.

## Files:
- `{model_name}.pt`: Model weights
- `tokenizer_config.json`: Tokenizer configuration
- `special_tokens_map.json`: Special tokens mapping
- `vocab.txt` or `tokenizer.json`: Vocabulary file

## Loading the model:

```python
import torch
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('{output_dir}')

# Load model
model = YourModelClass() # Initialize your model class
model.load_state_dict(torch.load('{model_name}.pt'))
model.eval()
```

## Model Information:
- Model type: {model.__class__.__name__}
- Saved on: {datetime.now().isoformat()}
"""

    readme_path = os.path.join(output_dir, 'README.md')
    with open(readme_path, 'w') as f:
        f.write(readme_content)

    logger.info(f"Model saved for inference in {output_dir}")
|
||||
518
utils/config_loader.py
Normal file
518
utils/config_loader.py
Normal file
@@ -0,0 +1,518 @@
|
||||
"""
|
||||
Centralized configuration loader for BGE fine-tuning project
|
||||
"""
|
||||
import os
|
||||
import toml
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resolve_model_path(
    model_key: str,
    config_instance: 'ConfigLoader',
    model_name_or_path: Optional[str] = None,
    cache_dir: Optional[str] = None
) -> Tuple[str, str]:
    """
    Resolve model path with automatic local/remote fallback

    Resolution priority:
      1. explicit CLI override (model_name_or_path)
      2. an existing, valid local directory from the [model_paths] config
      3. a remote hub identifier (HuggingFace or ModelScope)

    Args:
        model_key: Key in model_paths config (e.g., 'bge_m3', 'bge_reranker')
        config_instance: ConfigLoader instance
        model_name_or_path: Optional override from CLI
        cache_dir: Optional cache directory

    Returns:
        Tuple of (resolved_model_path, resolved_cache_dir)
    """
    # Default remote model names for each model type
    remote_models = {
        'bge_m3': 'BAAI/bge-m3',
        'bge_reranker': 'BAAI/bge-reranker-base',
    }

    # Get model source (huggingface or modelscope)
    model_source = config_instance.get('model_paths', 'source', 'huggingface')

    # If ModelScope is used, adjust model names
    # NOTE(review): these ModelScope ids are GTE/CoROM models, not BGE —
    # confirm they are the intended ModelScope substitutes for bge_m3 /
    # bge_reranker.
    if model_source == 'modelscope':
        remote_models = {
            'bge_m3': 'damo/nlp_gte_sentence-embedding_chinese-base',
            'bge_reranker': 'damo/nlp_corom_passage-ranking_english-base',
        }

    # 1. Check CLI override first (only if explicitly provided)
    if model_name_or_path is not None:
        resolved_model_path = model_name_or_path
        logger.info(f"[CLI OVERRIDE] Using model path: {resolved_model_path}")
    else:
        # 2. Check local path from config
        local_path_raw = config_instance.get('model_paths', model_key, None)

        # Handle list/string conversion (TOML values may arrive as lists)
        if isinstance(local_path_raw, list):
            local_path = local_path_raw[0] if local_path_raw else None
        else:
            local_path = local_path_raw

        if local_path and isinstance(local_path, str) and os.path.exists(local_path):
            # Check if it's a valid model directory
            if _is_valid_model_directory(local_path):
                resolved_model_path = local_path
                logger.info(f"[LOCAL] Using local model: {resolved_model_path}")
            else:
                logger.warning(f"[LOCAL] Invalid model directory: {local_path}, falling back to remote")
                resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
        else:
            # 3. Fall back to remote model
            resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
            logger.info(f"[REMOTE] Using remote model: {resolved_model_path}")

            # If local path was configured but doesn't exist, we'll download to that location
            if local_path and isinstance(local_path, str):
                logger.info(f"[DOWNLOAD] Will download to configured path: {local_path}")
                # Create directory if it doesn't exist
                os.makedirs(local_path, exist_ok=True)

    # Resolve cache directory (CLI argument wins over config)
    if cache_dir is not None:
        resolved_cache_dir = cache_dir
    else:
        cache_dir_raw = config_instance.get('model_paths', 'cache_dir', './cache/data')
        if isinstance(cache_dir_raw, list):
            resolved_cache_dir = cache_dir_raw[0] if cache_dir_raw else './cache/data'
        else:
            resolved_cache_dir = cache_dir_raw if cache_dir_raw else './cache/data'

    # Ensure cache directory exists
    if isinstance(resolved_cache_dir, str):
        os.makedirs(resolved_cache_dir, exist_ok=True)

    return resolved_model_path, resolved_cache_dir
|
||||
|
||||
|
||||
def _is_valid_model_directory(path: str) -> bool:
|
||||
"""Check if a directory contains a valid model"""
|
||||
required_files = ['config.json']
|
||||
optional_files = ['tokenizer_config.json', 'vocab.txt', 'tokenizer.json']
|
||||
model_files = ['pytorch_model.bin', 'model.safetensors', 'pytorch_model.bin.index.json']
|
||||
|
||||
path_obj = Path(path)
|
||||
if not path_obj.exists() or not path_obj.is_dir():
|
||||
return False
|
||||
|
||||
# Check for required files
|
||||
for file in required_files:
|
||||
if not (path_obj / file).exists():
|
||||
return False
|
||||
|
||||
# Check for at least one model file
|
||||
has_model_file = any((path_obj / file).exists() for file in model_files)
|
||||
if not has_model_file:
|
||||
return False
|
||||
|
||||
# Check for at least one tokenizer file (for embedding models)
|
||||
has_tokenizer_file = any((path_obj / file).exists() for file in optional_files)
|
||||
if not has_tokenizer_file:
|
||||
logger.warning(f"Model directory {path} missing tokenizer files, but proceeding anyway")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def download_model_if_needed(
    model_name_or_path: str,
    local_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    model_source: str = 'huggingface'
) -> str:
    """Download a remote model to *local_path* when necessary.

    Args:
        model_name_or_path: Model identifier or path
        local_path: Target local path for download
        cache_dir: Cache directory
        model_source: 'huggingface' or 'modelscope'

    Returns:
        Final model path (local if downloaded, original if already local)
    """
    # Paths that already exist on disk need no download.
    if os.path.exists(model_name_or_path):
        return model_name_or_path

    # Without a target directory, defer caching to the transformers library.
    if not local_path:
        return model_name_or_path

    # Reuse a previously downloaded copy when it passes validation.
    if os.path.exists(local_path) and _is_valid_model_directory(local_path):
        logger.info(f"Model already exists locally: {local_path}")
        return local_path

    try:
        logger.info(f"Downloading model {model_name_or_path} to {local_path}")

        if model_source == 'huggingface':
            from transformers import AutoModel, AutoTokenizer

            # Fetch weights and tokenizer through the HF cache, then persist
            # a standalone copy under local_path.
            model = AutoModel.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            os.makedirs(local_path, exist_ok=True)
            model.save_pretrained(local_path)
            tokenizer.save_pretrained(local_path)

            logger.info(f"Model downloaded successfully to {local_path}")
            return local_path

        elif model_source == 'modelscope':
            try:
                from modelscope import snapshot_download  # type: ignore

                downloaded_path = snapshot_download(
                    model_name_or_path,
                    cache_dir=cache_dir,
                    local_dir=local_path
                )
                logger.info(f"Model downloaded successfully from ModelScope to {downloaded_path}")
                return downloaded_path

            except ImportError:
                # ModelScope missing: retry the whole flow via HuggingFace.
                logger.warning("ModelScope not available, falling back to HuggingFace")
                return download_model_if_needed(
                    model_name_or_path, local_path, cache_dir, 'huggingface'
                )

    except Exception as e:
        logger.error(f"Failed to download model {model_name_or_path}: {e}")
        logger.info("Returning original model path - transformers will handle caching")

    # Fallback: unknown source or failed download.
    return model_name_or_path
|
||||
|
||||
|
||||
class ConfigLoader:
    """Centralized configuration management for BGE fine-tuning

    Parameter Precedence (highest to lowest):
    1. Command-line arguments (CLI)
    2. Model-specific config ([m3], [reranker])
    3. [hardware] section in config.toml
    4. [training] section in config.toml
    5. Hardcoded defaults in code

    Implemented as a process-wide singleton: ``__new__`` reuses a single
    instance and the parsed TOML is cached on the class-level ``_config``
    attribute, so every ConfigLoader() in the process shares one config.
    """

    _instance = None  # singleton instance, created on first construction
    _config = None    # parsed configuration dict, shared by all instances

    def __new__(cls):
        # Classic singleton: every ConfigLoader() call yields the same object.
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # Lazy load: only the first construction parses the config file.
        if self._config is None:
            self._load_config()

    def _find_config_file(self) -> str:
        """Find config.toml file in project hierarchy"""
        # Check environment variable FIRST (highest priority)
        if 'BGE_CONFIG_PATH' in os.environ:
            config_path = Path(os.environ['BGE_CONFIG_PATH'])
            if config_path.exists():
                return str(config_path)
            else:
                logger.warning(f"BGE_CONFIG_PATH set to {config_path} but file doesn't exist, falling back to default")

        current_dir = Path(__file__).parent

        # Search up the directory tree for config.toml
        for _ in range(5):  # Limit search depth
            config_path = current_dir / "config.toml"
            if config_path.exists():
                return str(config_path)
            current_dir = current_dir.parent

        # Default location (may not exist; _load_config handles that case)
        return os.path.join(os.getcwd(), "config.toml")

    def _load_config(self):
        """Load configuration from TOML file"""
        config_path = self._find_config_file()

        try:
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = toml.load(f)
                # Guard against an empty/None parse result.
                if loaded is None:
                    loaded = {}
                self._config = dict(loaded)
                logger.info(f"Loaded configuration from {config_path}")
            else:
                # No file found: fall back to hardcoded defaults and write
                # them out so the user has a template to edit.
                logger.warning(f"Config file not found at {config_path}, using defaults")
                self._config = self._get_default_config()
                self._save_default_config(config_path)
                logger.warning("Default config created. Please check your configuration file.")
        except Exception as e:
            # Any parse/IO failure degrades to defaults rather than raising.
            logger.error(f"Error loading config: {e}")
            self._config = self._get_default_config()

        # Validation for unset/placeholder values
        self._validate_config()

    def _validate_config(self):
        """Validate config for unset/placeholder values"""
        # Warn (don't fail) on a missing or template API key.
        api_key = self.get('api', 'api_key', None)
        if not api_key or api_key == 'your-api-key-here':
            logger.warning("API key is unset or placeholder. Set it in config.toml or via the BGE_API_KEY environment variable.")

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration

        Returns the full hardcoded default tree; every section here mirrors
        a [section] table expected in config.toml.
        """
        return {
            'api': {
                'base_url': 'https://api.deepseek.com/v1',
                'api_key': 'your-api-key-here',
                'model': 'deepseek-chat',
                'max_retries': 3,
                'timeout': 30,
                'temperature': 0.7
            },
            'training': {
                'default_batch_size': 4,
                'default_learning_rate': 1e-5,
                'default_num_epochs': 3,
                'default_warmup_ratio': 0.1,
                'default_weight_decay': 0.01,
                'gradient_accumulation_steps': 4
            },
            'model_paths': {
                'bge_m3': './models/bge-m3',
                'bge_reranker': './models/bge-reranker-base',
                'cache_dir': './cache/data'
            },
            'data': {
                'max_query_length': 64,
                'max_passage_length': 512,
                'max_seq_length': 512,
                'validation_split': 0.1,
                'random_seed': 42
            },
            'augmentation': {
                'enable_query_augmentation': False,
                'enable_passage_augmentation': False,
                'augmentation_methods': ['paraphrase', 'back_translation'],
                'num_augmentations_per_sample': 2,
                'augmentation_temperature': 0.8
            },
            'hard_negative_mining': {
                'enable_mining': True,
                'num_hard_negatives': 15,
                'negative_range_min': 10,
                'negative_range_max': 100,
                'margin': 0.1,
                'index_refresh_interval': 1000,
                'index_type': 'IVF',
                'nlist': 100
            },
            'evaluation': {
                'eval_batch_size': 32,
                'k_values': [1, 3, 5, 10, 20, 50, 100],
                'metrics': ['recall', 'precision', 'map', 'mrr', 'ndcg'],
                'save_predictions': True,
                'predictions_dir': './output/predictions'
            },
            'logging': {
                'log_level': 'INFO',
                'log_to_file': True,
                'log_dir': './logs',
                'tensorboard_dir': './logs/tensorboard',
                'wandb_project': 'bge-finetuning',
                'wandb_entity': '',
                'use_wandb': False
            },
            'distributed': {
                'backend': 'nccl',
                'init_method': 'env://',
                'find_unused_parameters': True
            },
            'optimization': {
                'gradient_checkpointing': True,
                'fp16': True,
                'bf16': False,
                'fp16_opt_level': 'O1',
                'max_grad_norm': 1.0,
                'adam_epsilon': 1e-8,
                'adam_beta1': 0.9,
                'adam_beta2': 0.999
            },
            'hardware': {
                'cuda_visible_devices': '0,1,2,3',
                'dataloader_num_workers': 4,
                'dataloader_pin_memory': True,
                'per_device_train_batch_size': 4,
                'per_device_eval_batch_size': 8
            },
            'checkpoint': {
                'save_strategy': 'steps',
                'save_steps': 1000,
                'save_total_limit': 3,
                'load_best_model_at_end': True,
                'metric_for_best_model': 'eval_recall@10',
                'greater_is_better': True,
                'resume_from_checkpoint': ''
            },
            'export': {
                'export_format': ['pytorch', 'onnx'],
                'optimize_for_inference': True,
                'quantize': False,
                'quantization_bits': 8
            }
        }

    def _save_default_config(self, config_path: str):
        """Save default configuration to file"""
        try:
            # NOTE(review): when config_path is a bare filename,
            # os.path.dirname() returns '' and makedirs('') raises — the
            # except below swallows that and only logs; confirm intended.
            os.makedirs(os.path.dirname(config_path), exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                config_to_save = self._config if isinstance(self._config, dict) else {}
                toml.dump(config_to_save, f)
            logger.info(f"Created default config at {config_path}")
        except Exception as e:
            logger.error(f"Error saving default config: {e}")

    def get(self, section: str, key: str, default=None):
        """Get a config value, logging its source (CLI, config, default)."""
        # CLI overrides are handled outside this loader; here we check config and default
        value = default
        source = 'default'
        if isinstance(self._config, dict) and section in self._config and key in self._config[section]:
            value = self._config[section][key]
            source = 'config.toml'
        # Robust list/tuple conversion
        # NOTE(review): this also splits comma/semicolon-containing *default*
        # strings, not just values read from the file — confirm intended.
        if isinstance(value, str) and (',' in value or ';' in value):
            try:
                value = [v.strip() for v in value.replace(';', ',').split(',') if v.strip()]
                source += ' (parsed as list)'
            except Exception as e:
                logging.warning(f"Failed to parse list for {section}.{key}: {e}")
        # NOTE(review): logs via the root `logging` module (not the module
        # `logger`) and on every access, at WARNING level for defaults —
        # can be noisy; confirm intended.
        log_func = logging.info if source != 'default' else logging.warning
        log_func(f"Config param {section}.{key} = {value} (source: {source})")
        return value

    def get_section(self, section: str) -> Dict[str, Any]:
        """Get entire configuration section

        Returns an empty dict when the section is missing or the config
        failed to load.
        """
        if isinstance(self._config, dict):
            return self._config.get(section, {})
        return {}

    def update(self, key: str, value: Any):
        """Update configuration value. Logs a warning if overriding an existing value.

        Args:
            key: Dotted path, e.g. 'training.default_batch_size'; missing
                intermediate tables are created.
            value: New value to store.
        """
        if self._config is None:
            self._config = {}

        keys = key.split('.')
        config = self._config

        # Navigate to the parent of the final key
        for k in keys[:-1]:
            if k not in config:
                config[k] = {}
            config = config[k]

        # Log if overriding
        old_value = config.get(keys[-1], None)
        if old_value is not None and old_value != value:
            logger.warning(f"Parameter '{key}' overridden: {old_value} -> {value} (higher-precedence source)")
        config[keys[-1]] = value

    def reload(self):
        """Reload configuration from file

        Clears the cached dict so _load_config re-reads (honoring a changed
        BGE_CONFIG_PATH).
        """
        self._config = None
        self._load_config()
|
||||
|
||||
|
||||
# Global instance
# Module-level singleton shared by the whole project. Constructing it here
# triggers config loading at import time; ConfigLoader.__new__ guarantees
# later ConfigLoader() calls return this same object.
config = ConfigLoader()
|
||||
|
||||
|
||||
def get_config() -> ConfigLoader:
    """Return the module-level :class:`ConfigLoader` singleton."""
    return config
|
||||
|
||||
|
||||
def load_config_for_training(
    config_path: Optional[str] = None,
    override_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Load configuration specifically for training scripts

    Args:
        config_path: Optional path to config file
        override_params: Parameters to override from command line; keys use
            dotted 'section.key' notation

    Returns:
        Dictionary with all configuration parameters
    """
    global config

    # Reload with specific path if provided. Bug fix: BGE_CONFIG_PATH is now
    # restored in a finally block, so a failing reload() can no longer leak
    # the temporary override into the process environment.
    if config_path:
        old_path = os.environ.get('BGE_CONFIG_PATH')
        os.environ['BGE_CONFIG_PATH'] = config_path
        try:
            config.reload()
        finally:
            if old_path:
                os.environ['BGE_CONFIG_PATH'] = old_path
            else:
                del os.environ['BGE_CONFIG_PATH']

    # Snapshot every known section into a plain dict.
    section_names = [
        'api', 'training', 'model_paths', 'data', 'augmentation',
        'hard_negative_mining', 'evaluation', 'logging', 'distributed',
        'optimization', 'hardware', 'checkpoint', 'export', 'm3', 'reranker',
    ]
    training_config: Dict[str, Any] = {
        name: config.get_section(name) for name in section_names
    }

    # Apply overrides if provided (highest-precedence source).
    if override_params:
        for key, value in override_params.items():
            config.update(key, value)
            # Mirror the override into the returned dict as well, creating
            # intermediate tables for dotted keys.
            keys = key.split('.')
            current = training_config
            for k in keys[:-1]:
                if k not in current:
                    current[k] = {}
                current = current[k]
            current[keys[-1]] = value

    return training_config
|
||||
425
utils/distributed.py
Normal file
425
utils/distributed.py
Normal file
@@ -0,0 +1,425 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
import logging
|
||||
from typing import Optional, Any, List
|
||||
import random
|
||||
import numpy as np
|
||||
from .config_loader import get_config
|
||||
|
||||
# Import Ascend support utilities
# Huawei Ascend NPU helpers are optional: when the module (or its own
# dependencies) is missing, ASCEND_SUPPORT_AVAILABLE is set to False and
# callers fall back to CUDA/CPU code paths.
try:
    from .ascend_support import (
        check_ascend_availability,
        get_ascend_backend,
        setup_ascend_device,
        clear_ascend_cache
    )
    ASCEND_SUPPORT_AVAILABLE = True
except ImportError:
    ASCEND_SUPPORT_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_distributed(config=None):
    """Setup distributed training environment with proper error handling.

    Reads WORLD_SIZE / RANK / LOCAL_RANK from the environment (torchrun
    convention). Any failure degrades to single-process mode instead of
    raising.

    Args:
        config: Optional ConfigLoader-style object whose
            ('distributed', 'backend') value overrides the detected backend.
    """
    try:
        # Check if we're in a distributed environment
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        rank = int(os.environ.get('RANK', 0))
        local_rank = int(os.environ.get('LOCAL_RANK', 0))

        if world_size == 1:
            logger.info("Single GPU/CPU training mode")
            return

        # Pick a backend for the available hardware:
        # Ascend NPU (HCCL) -> CUDA (NCCL) -> CPU (Gloo).
        if ASCEND_SUPPORT_AVAILABLE and check_ascend_availability():
            backend = 'hccl'
            try:
                setup_ascend_device(local_rank)
                logger.info("Using Ascend NPU with HCCL backend")
            except Exception as e:
                logger.warning(f"Failed to setup Ascend device: {e}, falling back")
                backend = 'gloo'
        elif torch.cuda.is_available():
            backend = 'nccl'
            # Bind this process to its local GPU.
            torch.cuda.set_device(local_rank)
        else:
            backend = 'gloo'

        # Override backend from config if provided.
        # Bug fix: ConfigLoader.get takes (section, key, default); the old
        # call config.get('distributed.backend', backend) passed the dotted
        # string as the section and the detected backend as the key, so the
        # configured override never took effect.
        if config is not None:
            backend = config.get('distributed', 'backend', backend)

        # Prefer explicit tcp:// rendezvous when MASTER_ADDR is set,
        # otherwise rely on env:// initialization.
        if os.environ.get('MASTER_ADDR', None):
            init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ.get('MASTER_PORT', '29500')}"
        else:
            init_method = 'env://'

        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank
        )

        logger.info(f"Distributed training initialized: backend={backend}, world_size={world_size}, rank={rank}, local_rank={local_rank}")

    except Exception as e:
        logger.warning(f"Failed to initialize distributed training: {e}")
        logger.info("Falling back to single GPU/CPU training")
        # Force single-process values so later env reads stay consistent.
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'
        os.environ['LOCAL_RANK'] = '0'
|
||||
|
||||
|
||||
def cleanup_distributed():
    """Tear down the torch.distributed process group, ignoring failures."""
    try:
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Distributed process group destroyed.")
    except Exception as e:
        # Never let cleanup problems abort the surrounding shutdown path.
        logger.warning(f"Error during distributed cleanup: {e}")
|
||||
|
||||
|
||||
def is_main_process() -> bool:
    """True for rank 0, or for any process when not running distributed."""
    return dist.get_rank() == 0 if dist.is_initialized() else True
|
||||
|
||||
|
||||
def get_rank() -> int:
    """Rank of the current process (0 when distributed is not active)."""
    return dist.get_rank() if dist.is_initialized() else 0
|
||||
|
||||
|
||||
def get_world_size() -> int:
    """Total number of distributed processes (1 when not distributed)."""
    return dist.get_world_size() if dist.is_initialized() else 1
|
||||
|
||||
|
||||
def get_local_rank() -> int:
    """Local (per-node) rank, read from the LOCAL_RANK environment variable.

    Returns 0 when distributed training is not initialized.
    """
    if not dist.is_initialized():
        return 0
    return int(os.environ.get('LOCAL_RANK', 0))
|
||||
|
||||
|
||||
def synchronize():
    """Barrier across all processes; a no-op outside distributed runs."""
    if not dist.is_initialized():
        return
    dist.barrier()
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """All-reduce *tensor* across processes.

    Args:
        tensor: Tensor to reduce (modified in place when distributed).
        op: Reduction operation (default: sum).

    Returns:
        The reduced tensor; unchanged when distributed is not active.
    """
    if dist.is_initialized():
        dist.all_reduce(tensor, op=op)
    return tensor
|
||||
|
||||
|
||||
def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """Gather *tensor* from every process.

    Args:
        tensor: Tensor to gather.

    Returns:
        One tensor per rank; a single-element list when not distributed.
    """
    if not dist.is_initialized():
        return [tensor]

    gathered = [torch.empty_like(tensor) for _ in range(get_world_size())]
    dist.all_gather(gathered, tensor)
    return gathered
|
||||
|
||||
|
||||
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast *tensor* from rank *src* to every process.

    Args:
        tensor: Tensor to broadcast (received in place on non-src ranks).
        src: Source rank.

    Returns:
        The tensor; unchanged when distributed is not active.
    """
    if dist.is_initialized():
        dist.broadcast(tensor, src=src)
    return tensor
|
||||
|
||||
|
||||
def reduce_dict(input_dict: dict, average: bool = True) -> dict:
    """Reduce every value in *input_dict* across processes.

    Args:
        input_dict: Mapping of metric name -> scalar tensor.
        average: Divide by world size after summing.

    Returns:
        A new dict of Python numbers; the input is returned untouched when
        running on fewer than two processes.
    """
    if not dist.is_initialized():
        return input_dict

    world_size = get_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        # Keys are sorted so every rank stacks values in the same order —
        # all_reduce requires identical layout across ranks.
        names = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(stacked)

        if average:
            stacked /= world_size

        return {name: val.item() for name, val in zip(names, stacked)}
|
||||
|
||||
|
||||
def setup_model_for_distributed(
    model: torch.nn.Module,
    device_ids: Optional[List[int]] = None,
    output_device: Optional[int] = None,
    find_unused_parameters: bool = True,
    broadcast_buffers: bool = True
) -> torch.nn.Module:
    """Wrap *model* in DistributedDataParallel when distributed is active.

    Args:
        model: Model to wrap.
        device_ids: GPU indices to use (defaults to the local rank).
        output_device: Device where outputs are gathered (defaults to the
            first entry of device_ids).
        find_unused_parameters: Forwarded to DDP.
        broadcast_buffers: Forwarded to DDP.

    Returns:
        The DDP-wrapped model, or the original model unchanged when not
        running distributed.
    """
    if not dist.is_initialized():
        return model

    device_ids = device_ids if device_ids is not None else [get_local_rank()]
    output_device = output_device if output_device is not None else device_ids[0]

    wrapped = DDP(
        model,
        device_ids=device_ids,
        output_device=output_device,
        find_unused_parameters=find_unused_parameters,
        broadcast_buffers=broadcast_buffers
    )
    logger.info(f"Wrapped model for distributed training on devices {device_ids}")
    return wrapped
|
||||
|
||||
|
||||
def create_distributed_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = False,
    seed: int = 42
) -> DataLoader:
    """Build a DataLoader, attaching a DistributedSampler when needed.

    Args:
        dataset: Dataset to load.
        batch_size: Batch size per process.
        shuffle: Whether to shuffle (delegated to the sampler when distributed).
        num_workers: Number of data loading workers.
        pin_memory: Whether to pin memory.
        drop_last: Whether to drop the last incomplete batch.
        seed: Shuffle seed for the DistributedSampler.

    Returns:
        The configured DataLoader.
    """
    sampler = None
    if dist.is_initialized():
        sampler = DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last
        )
        # The sampler owns shuffling; DataLoader must not shuffle again.
        shuffle = False

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last
    )
|
||||
|
||||
|
||||
def set_random_seed(seed: int, deterministic: bool = False):
    """Seed the Python, NumPy, and torch RNGs for reproducibility.

    Args:
        seed: Random seed.
        deterministic: Also force deterministic cuDNN kernels (disables the
            autotuner, typically slower).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    if deterministic and torch.cuda.is_available():
        # Configure cuDNN determinism settings
        cudnn.deterministic = True
        cudnn.benchmark = False

    logger.info(f"Set random seed to {seed} (deterministic={deterministic})")
|
||||
|
||||
|
||||
def print_distributed_info():
    """Print distributed training information

    Logs backend, world size, rank, and master address/port; emits a single
    notice when distributed training is not initialized.
    """
    if not dist.is_initialized():
        logger.info("Distributed training not initialized")
        return

    logger.info(f"Distributed Info:")
    logger.info(f" Backend: {dist.get_backend()}")
    logger.info(f" World size: {get_world_size()}")
    logger.info(f" Rank: {get_rank()}")
    logger.info(f" Local rank: {get_local_rank()}")
    logger.info(f" Master addr: {os.environ.get('MASTER_ADDR', 'N/A')}")
    logger.info(f" Master port: {os.environ.get('MASTER_PORT', 'N/A')}")
|
||||
|
||||
|
||||
class DistributedMetricAggregator:
    """Accumulates scalar metrics and reduces them across all ranks."""

    def __init__(self):
        # name -> tensor holding the latest value recorded on this rank
        self.metrics = {}

    def update(self, **kwargs):
        """Record metric values; tensors are detached, plain numbers converted."""
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                tensor = value.detach()
            else:
                tensor = torch.tensor(value, dtype=torch.float32)

            # Collectives need the tensor on the GPU when one is available.
            if torch.cuda.is_available():
                tensor = tensor.cuda()

            self.metrics[name] = tensor

    def aggregate(self, average: bool = True) -> dict:
        """Sum (and optionally average) the recorded metrics across ranks."""
        if not dist.is_initialized():
            return {name: val.item() if isinstance(val, torch.Tensor) else val
                    for name, val in self.metrics.items()}

        world_size = get_world_size()
        result = {}
        for name, tensor in self.metrics.items():
            all_reduce(tensor)
            if average:
                tensor = tensor / world_size
            result[name] = tensor.item()
        return result

    def reset(self):
        """Drop every recorded metric."""
        self.metrics = {}
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
    """Decorator factory: run the wrapped function only on the master rank.

    Non-master processes receive None instead of the function's result.
    """

    def decorator(func):
        def wrapper(*w_args, **w_kwargs):
            if not is_main_process():
                # Return None on non-master processes
                return None
            return func(*w_args, **w_kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def log_on_master(logger: logging.Logger):
    """Wrap *logger* so its logging calls become no-ops off the master rank."""

    class MasterOnlyLogger:
        def __init__(self, wrapped):
            self.logger = wrapped

        def __getattr__(self, name):
            # Only consulted for names not set in __init__: forwards to the
            # real logger on rank 0, swallows calls everywhere else.
            if not is_main_process():
                return lambda *args, **kwargs: None
            return getattr(self.logger, name)

    return MasterOnlyLogger(logger)
|
||||
344
utils/logging.py
Normal file
344
utils/logging.py
Normal file
@@ -0,0 +1,344 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
import json
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_progress_bar(
    iterable,
    desc: str = "",
    total: Optional[int] = None,
    disable: bool = False,
    position: int = 0,
    leave: bool = True,
    ncols: Optional[int] = None,
    unit: str = "it",
    dynamic_ncols: bool = True,
):
    """
    Create a tqdm progress bar with consistent settings across the project.

    Args:
        iterable: Iterable to wrap with progress bar
        desc: Description for the progress bar
        total: Total number of iterations (if None, will try to get from iterable)
        disable: Whether to disable the progress bar
        position: Position of the progress bar (for multiple bars)
        leave: Whether to leave the progress bar after completion
        ncols: Width of the progress bar
        unit: Unit of iteration (e.g., "batch", "step")
        dynamic_ncols: Automatically adjust width based on terminal size

    Returns:
        tqdm progress bar object
    """
    # Disable the bar on non-main ranks so distributed runs emit only one.
    # Bug fix: the bare `except:` here also swallowed KeyboardInterrupt and
    # SystemExit; narrowed to Exception.
    try:
        import torch.distributed as dist
        if dist.is_initialized() and dist.get_rank() != 0:
            # Disable progress bar for non-main processes
            disable = True
    except Exception:
        pass

    # Create progress bar with consistent settings
    return tqdm(
        iterable,
        desc=desc,
        total=total,
        disable=disable,
        position=position,
        leave=leave,
        ncols=ncols,
        unit=unit,
        dynamic_ncols=dynamic_ncols,
        # Use custom write method to avoid conflicts with logging
        file=sys.stdout,
        # Ensure smooth updates
        smoothing=0.1,
        # Format with consistent style
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
    )
|
||||
|
||||
|
||||
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that routes records through tqdm.write so emitted
    log lines do not corrupt any active progress bars."""

    def emit(self, record):
        try:
            tqdm.write(self.format(record), file=sys.stdout)
            self.flush()
        except Exception:
            self.handleError(record)
|
||||
|
||||
|
||||
def setup_logging(
    output_dir: Optional[str] = None,
    log_to_file: bool = False,
    log_level: int = logging.INFO,
    log_to_console: bool = True,
    use_tqdm_handler: bool = True,
    rank: int = 0,
    world_size: int = 1,
):
    """Configure the root logger for (optionally distributed) training.

    Args:
        output_dir: Directory for log files (defaults to ./logs).
        log_to_file: Whether to log to a file.
        log_level: Logging level.
        log_to_console: Whether to log to the console.
        use_tqdm_handler: Use a tqdm-compatible handler for console logging.
        rank: Process rank for distributed training.
        world_size: World size for distributed training.
    """
    handlers: list[logging.Handler] = []
    log_path = None

    # Console output is restricted to rank 0 so distributed runs do not
    # produce interleaved duplicate lines.
    if log_to_console and (world_size == 1 or rank == 0):
        if use_tqdm_handler:
            console_handler = TqdmLoggingHandler()
        else:
            console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(log_level)
        handlers.append(console_handler)

    if log_to_file:
        if output_dir is None:
            output_dir = "./logs"
        os.makedirs(output_dir, exist_ok=True)

        # Per-rank file names keep distributed logs separated.
        log_filename = f"train_rank{rank}.log" if world_size > 1 else "train.log"
        log_path = os.path.join(output_dir, log_filename)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(log_level)
        handlers.append(file_handler)

    # Distributed runs embed the rank in every record.
    if world_size > 1:
        formatter = logging.Formatter(
            f'[%(asctime)s][Rank {rank}] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    else:
        formatter = logging.Formatter(
            '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    for handler in handlers:
        handler.setFormatter(formatter)

    # Replace any pre-existing handlers on the root logger to avoid
    # duplicated output on repeated setup calls.
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    for stale in root_logger.handlers[:]:
        root_logger.removeHandler(stale)
    for handler in handlers:
        root_logger.addHandler(handler)

    if rank == 0:
        logging.info(
            f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
            f"Log file: {log_path if log_to_file else 'none'} | "
            f"World size: {world_size}"
        )
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
    """Return the :class:`logging.Logger` registered under *name*.

    Thin convenience wrapper around :func:`logging.getLogger` so callers
    in this package only need to import from this module.
    """
    return logging.getLogger(name)
|
||||
|
||||
|
||||
class MetricsLogger:
    """Persists training metrics and run metadata as JSON files.

    Per-step metrics are appended to ``metrics.jsonl`` (one JSON object
    per line); the run configuration and final summary each get their
    own pretty-printed JSON file inside *output_dir*.
    """

    def __init__(self, output_dir: str):
        # Remember the target directory and ensure it exists up front so
        # the first append to the JSONL file cannot fail on a missing path.
        self.output_dir = output_dir
        self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
        os.makedirs(output_dir, exist_ok=True)

    def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
        """Append one timestamped metrics record for *step* to the JSONL file."""
        record = {
            'step': step,
            'phase': phase,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics
        }
        with open(self.metrics_file, 'a') as fh:
            fh.write(json.dumps(record) + '\n')

    def log_config(self, config: Dict[str, Any]):
        """Write the run configuration to ``config.json``."""
        self._write_json('config.json', config)

    def log_summary(self, summary: Dict[str, Any]):
        """Write the final training summary to ``summary.json``."""
        self._write_json('summary.json', summary)

    def _write_json(self, filename: str, payload: Dict[str, Any]):
        # Shared helper: dump *payload* as indented JSON under output_dir.
        with open(os.path.join(self.output_dir, filename), 'w') as fh:
            json.dump(payload, fh, indent=2)
|
||||
|
||||
|
||||
class TensorBoardLogger:
    """Wrapper around torch's ``SummaryWriter`` that degrades gracefully.

    When *enabled* is False, or the tensorboard package is missing, the
    instance disables itself and every logging call is a silent no-op.
    """

    def __init__(self, log_dir: str, enabled: bool = True):
        self.enabled = enabled
        if not self.enabled:
            return
        try:
            # Imported lazily so the module works without tensorboard installed.
            from torch.utils.tensorboard.writer import SummaryWriter
            self.writer = SummaryWriter(log_dir)
        except ImportError:
            logging.getLogger(__name__).warning(
                "TensorBoard not available. Install with: pip install tensorboard"
            )
            self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int):
        """Record a single scalar value."""
        if not self.enabled:
            return
        self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
        """Record several related scalars under one main tag."""
        if not self.enabled:
            return
        self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Record a histogram of *values*."""
        if not self.enabled:
            return
        self.writer.add_histogram(tag, values, step)

    def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
        """Record embeddings for the projector visualization."""
        if not self.enabled:
            return
        self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)

    def flush(self):
        """Flush pending events to disk."""
        if not self.enabled:
            return
        self.writer.flush()

    def close(self):
        """Close the underlying writer."""
        if not self.enabled:
            return
        self.writer.close()
|
||||
|
||||
|
||||
class ProgressLogger:
    """Periodically logs loss, throughput, and an ETA during training."""

    def __init__(self, total_steps: int, log_interval: int = 100):
        self.total_steps = total_steps
        self.log_interval = log_interval
        self.logger = logging.getLogger(__name__)
        # Wall-clock anchor used for speed/ETA estimates.
        self.start_time = datetime.now()

    def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
        """Emit one progress line every ``log_interval`` steps."""
        if step % self.log_interval != 0:
            return

        elapsed = (datetime.now() - self.start_time).total_seconds()
        rate = step / elapsed if elapsed > 0 else 0
        remaining = (self.total_steps - step) / rate if rate > 0 else 0

        parts = [f"Step {step}/{self.total_steps}", f"Loss: {loss:.4f}"]
        if metrics:
            parts.append(" | ".join(f"{k}: {v:.4f}" for k, v in metrics.items()))
        parts.append(f"Speed: {rate:.2f} steps/s")
        parts.append(f"ETA: {self._format_time(remaining)}")

        self.logger.info(" | ".join(parts))

    def _format_time(self, seconds: float) -> str:
        """Render a duration as e.g. ``1h 2m 5s``, ``2m 5s``, or ``5s``."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = int(seconds % 60)

        if hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        if minutes > 0:
            return f"{minutes}m {seconds}s"
        return f"{seconds}s"
|
||||
|
||||
|
||||
def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
    """Log parameter counts for *model* (and, at DEBUG, per-parameter shapes).

    Args:
        model: The module whose parameters are summarized.
        logger: Destination logger; defaults to this module's logger.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    # Single pass over the parameters collects both totals.
    n_total = 0
    n_trainable = 0
    for param in model.parameters():
        count = param.numel()
        n_total += count
        if param.requires_grad:
            n_trainable += count

    logger.info(f"Model: {model.__class__.__name__}")
    logger.info(f"Total parameters: {n_total:,}")
    logger.info(f"Trainable parameters: {n_trainable:,}")
    logger.info(f"Non-trainable parameters: {n_total - n_trainable:,}")

    # Shapes are verbose, so they only appear at DEBUG level.
    logger.debug("Parameter shapes:")
    for name, param in model.named_parameters():
        logger.debug(f"  {name}: {list(param.shape)}")
|
||||
|
||||
|
||||
def log_training_summary(
    config: Dict[str, Any],
    metrics: Dict[str, float],
    elapsed_time: float,
    output_dir: str
):
    """Persist a JSON training summary to *output_dir* and echo it to the log.

    Args:
        config: Run configuration embedded in the summary file.
        metrics: Final metric values (name -> float).
        elapsed_time: Total wall-clock training time in seconds.
        output_dir: Directory that receives ``training_summary.json``.
    """
    logger = logging.getLogger(__name__)

    payload = {
        'config': config,
        'final_metrics': metrics,
        'training_time': elapsed_time,
        'timestamp': datetime.now().isoformat()
    }

    # Persist first so the summary survives even if console logging fails.
    summary_file = os.path.join(output_dir, 'training_summary.json')
    with open(summary_file, 'w') as fh:
        json.dump(payload, fh, indent=2)

    banner = "=" * 50
    logger.info(banner)
    logger.info("TRAINING SUMMARY")
    logger.info(banner)
    logger.info(f"Training time: {elapsed_time:.2f} seconds")
    logger.info("Final metrics:")
    for name, value in metrics.items():
        logger.info(f"  {name}: {value:.4f}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info(banner)
|
||||
Reference in New Issue
Block a user