"""
Utility modules for logging, checkpointing, and distributed training
"""
# Standard library first, then package-local re-exports (PEP 8 grouping).
import logging
import os
import sys

from .logging import (
    setup_logging,
    get_logger,
    MetricsLogger,
    TensorBoardLogger,
    ProgressLogger,
    log_model_info,
    log_training_summary,
)
from .checkpoint import CheckpointManager, save_model_for_inference
from .distributed import (
    setup_distributed,
    cleanup_distributed,
    is_main_process,
    get_rank,
    get_world_size,
    get_local_rank,
    synchronize,
    all_reduce,
    all_gather,
    broadcast,
    reduce_dict,
    setup_model_for_distributed,
    create_distributed_dataloader,
    set_random_seed,
    print_distributed_info,
    DistributedMetricAggregator,
    save_on_master,
    log_on_master,
)
logger = logging.getLogger(__name__)


def optional_import(module_name, feature_desc=None, warn_only=True):
    """Import an optional module by name, returning it or ``None``.

    When the module is missing, a message is logged — as a warning by
    default, or as an error when ``warn_only`` is False — and ``None``
    is returned. ``feature_desc`` names the feature that becomes
    unavailable without the dependency.
    """
    import importlib

    try:
        return importlib.import_module(module_name)
    except ImportError:
        msg = f"Optional dependency '{module_name}' not found."
        if feature_desc:
            msg += f" {feature_desc} will be unavailable."
        # Select the log channel once instead of branching twice.
        emit = logger.warning if warn_only else logger.error
        emit(msg)
        return None
# Import guard for optional dependencies


def _check_optional_dependencies():
    """Best-effort probe of optional dependencies, warning about missing ones.

    Checks are only made when the corresponding feature is likely to be
    used (env var, project config, pytest invocation). Never raises: every
    probe is defensive so a broken optional stack cannot prevent this
    package from importing.
    """
    # Ascend NPU support: the original probe called torch.device() with no
    # arguments, which raises TypeError whenever torch is installed. Use
    # the DEVICE env var and the presence of a torch.npu namespace instead.
    torch = optional_import('torch', warn_only=True)
    wants_npu = os.environ.get('DEVICE', '').lower() == 'npu'
    if torch and (wants_npu or hasattr(torch, 'npu')):
        optional_import('torch_npu', feature_desc='Ascend NPU support')

    # Load the project config once; it is optional itself. Previously the
    # tensorboard check read `config` outside the try that defined it,
    # relying on a masked NameError.
    config = None
    try:
        from utils.config_loader import get_config
        config = get_config()
    except Exception:
        # Config machinery is optional; a missing/broken config is fine here.
        pass

    def _cfg(key, default=None):
        # Defensive config lookup: any failure means "use the default".
        if config is None:
            return default
        try:
            return config.get(key, default)  # type: ignore[arg-type]
        except Exception:
            return default

    # wandb: only warn when Weights & Biases logging is enabled in config.
    if _cfg('logging.use_wandb', False):
        optional_import('wandb', feature_desc='Weights & Biases logging')

    # tensorboard: only warn when a tensorboard directory is configured.
    if _cfg('logging.tensorboard_dir', None):
        optional_import('tensorboard', feature_desc='TensorBoard logging')

    # pytest: only relevant when the process was launched under pytest.
    # (`sys` is imported at module top; the redundant local import is gone.)
    if any('pytest' in arg for arg in sys.argv):
        optional_import('pytest', feature_desc='Some tests may not run')
# Public API of the utils package: the names re-exported from the
# submodules, grouped to mirror the import blocks at the top of the file.
# `from utils import *` yields exactly this set.
__all__ = [
    # Logging
    'setup_logging',
    'get_logger',
    'MetricsLogger',
    'TensorBoardLogger',
    'ProgressLogger',
    'log_model_info',
    'log_training_summary',

    # Checkpointing
    'CheckpointManager',
    'save_model_for_inference',

    # Distributed training
    'setup_distributed',
    'cleanup_distributed',
    'is_main_process',
    'get_rank',
    'get_world_size',
    'get_local_rank',
    'synchronize',
    'all_reduce',
    'all_gather',
    'broadcast',
    'reduce_dict',
    'setup_model_for_distributed',
    'create_distributed_dataloader',
    'set_random_seed',
    'print_distributed_info',
    'DistributedMetricAggregator',
    'save_on_master',
    'log_on_master'
]