"""
|
|
Utility modules for logging, checkpointing, and distributed training
|
|
"""
|
|
from .logging import (
|
|
setup_logging,
|
|
get_logger,
|
|
MetricsLogger,
|
|
TensorBoardLogger,
|
|
ProgressLogger,
|
|
log_model_info,
|
|
log_training_summary
|
|
)
|
|
from .checkpoint import CheckpointManager, save_model_for_inference
|
|
from .distributed import (
    setup_distributed,
    cleanup_distributed,
    is_main_process,
    get_rank,
    get_world_size,
    get_local_rank,
    synchronize,
    all_reduce,
    all_gather,
    broadcast,
    reduce_dict,
    setup_model_for_distributed,
    create_distributed_dataloader,
    set_random_seed,
    print_distributed_info,
    DistributedMetricAggregator,
    save_on_master,
    log_on_master
)

import os
import sys
import logging

logger = logging.getLogger(__name__)
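
# Typical usage from training code (illustrative sketch; only names re-exported by this
# package are referenced, and the exact argument lists are assumptions, not part of this file):
#
#   from utils import setup_logging, setup_distributed, CheckpointManager
#
#   setup_logging(...)
#   setup_distributed(...)
#   checkpoint_manager = CheckpointManager(...)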


def optional_import(module_name, feature_desc=None, warn_only=True):
    """Try to import an optional module. Log a warning if not found. Returns the module or None."""
    import importlib
    try:
        return importlib.import_module(module_name)
    except ImportError:
        msg = f"Optional dependency '{module_name}' not found."
        if feature_desc:
            msg += f" {feature_desc} will be unavailable."
        if warn_only:
            logger.warning(msg)
        else:
            logger.error(msg)
        return None
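

# Example of how a submodule might use optional_import (illustrative only; 'yaml' is a
# stand-in for any optional dependency and is not a requirement of this package):
#
#   yaml = optional_import('yaml', feature_desc='YAML config parsing')
#   if yaml is not None:
#       with open('config.yaml') as f:
#           config = yaml.safe_load(f)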


# Import guard for optional dependencies

def _check_optional_dependencies():
    """Check for optional dependencies and log a warning for each one that is missing but likely to be needed."""
    # Check for torch_npu if running on Ascend (detected via the DEVICE environment variable)
    torch = optional_import('torch', warn_only=True)
    if torch and os.environ.get('DEVICE', '').lower() == 'npu':
        optional_import('torch_npu', feature_desc='Ascend NPU support')

    # Load the project config once; the wandb and tensorboard checks below both need it
    config = None
    try:
        from utils.config_loader import get_config
        config = get_config()
    except Exception:
        pass

    # Check for wandb if Weights & Biases logging is enabled
    use_wandb = False
    if config is not None:
        try:
            use_wandb = config.get('logging.use_wandb', False)  # type: ignore[arg-type]
        except Exception:
            pass
    if use_wandb:
        optional_import('wandb', feature_desc='Weights & Biases logging')

    # Check for tensorboard if TensorBoard logging is configured
    tensorboard_dir = None
    if config is not None:
        try:
            tensorboard_dir = config.get('logging.tensorboard_dir', None)  # type: ignore[arg-type]
        except Exception:
            pass
    if tensorboard_dir:
        optional_import('tensorboard', feature_desc='TensorBoard logging')

    # Check for pytest if running tests (sys is imported at module level)
    if any('pytest' in arg for arg in sys.argv):
        optional_import('pytest', feature_desc='Some tests may not run')
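

# If the dependency check should run eagerly, one option (a sketch, not part of the original
# module) is to gate it behind an environment flag; CHECK_OPTIONAL_DEPS is an assumed,
# illustrative variable name:
#
#   if os.environ.get('CHECK_OPTIONAL_DEPS', '0') == '1':
#       _check_optional_dependencies()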


__all__ = [
    # Logging
    'setup_logging',
    'get_logger',
    'MetricsLogger',
    'TensorBoardLogger',
    'ProgressLogger',
    'log_model_info',
    'log_training_summary',

    # Checkpointing
    'CheckpointManager',
    'save_model_for_inference',

    # Distributed training
    'setup_distributed',
    'cleanup_distributed',
    'is_main_process',
    'get_rank',
    'get_world_size',
    'get_local_rank',
    'synchronize',
    'all_reduce',
    'all_gather',
    'broadcast',
    'reduce_dict',
    'setup_model_for_distributed',
    'create_distributed_dataloader',
    'set_random_seed',
    'print_distributed_info',
    'DistributedMetricAggregator',
    'save_on_master',
    'log_on_master'
]