""" Utility modules for logging, checkpointing, and distributed training """ from .logging import ( setup_logging, get_logger, MetricsLogger, TensorBoardLogger, ProgressLogger, log_model_info, log_training_summary ) from .checkpoint import CheckpointManager, save_model_for_inference from .distributed import ( setup_distributed, cleanup_distributed, is_main_process, get_rank, get_world_size, get_local_rank, synchronize, all_reduce, all_gather, broadcast, reduce_dict, setup_model_for_distributed, create_distributed_dataloader, set_random_seed, print_distributed_info, DistributedMetricAggregator, save_on_master, log_on_master ) import os import sys import logging logger = logging.getLogger(__name__) def optional_import(module_name, feature_desc=None, warn_only=True): """Try to import an optional module. Log a warning if not found. Returns the module or None.""" import importlib try: return importlib.import_module(module_name) except ImportError: msg = f"Optional dependency '{module_name}' not found." if feature_desc: msg += f" {feature_desc} will be unavailable." if warn_only: logger.warning(msg) else: logger.error(msg) return None # Import guard for optional dependencies def _check_optional_dependencies(): """Check for optional dependencies and log warnings if missing when features are likely to be used.""" # Check for torch_npu if running on Ascend torch = optional_import('torch', warn_only=True) if torch and ('npu' in str(getattr(torch, 'device', lambda: '')()).lower() or os.environ.get('DEVICE', '').lower() == 'npu'): optional_import('torch_npu', feature_desc='Ascend NPU support') # Check for wandb if Weights & Biases logging is enabled try: from utils.config_loader import get_config config = get_config() use_wandb = False try: use_wandb = config.get('logging.use_wandb', False) # type: ignore[arg-type] except Exception: pass if use_wandb: optional_import('wandb', feature_desc='Weights & Biases logging') except Exception: pass # Check for tensorboard if tensorboard logging is likely to be used tensorboard_dir = None try: tensorboard_dir = config.get('logging.tensorboard_dir', None) # type: ignore[arg-type] except Exception: pass if tensorboard_dir: optional_import('tensorboard', feature_desc='TensorBoard logging') # Check for pytest if running tests import sys if any('pytest' in arg for arg in sys.argv): optional_import('pytest', feature_desc='Some tests may not run') __all__ = [ # Logging 'setup_logging', 'get_logger', 'MetricsLogger', 'TensorBoardLogger', 'ProgressLogger', 'log_model_info', 'log_training_summary', # Checkpointing 'CheckpointManager', 'save_model_for_inference', # Distributed training 'setup_distributed', 'cleanup_distributed', 'is_main_process', 'get_rank', 'get_world_size', 'get_local_rank', 'synchronize', 'all_reduce', 'all_gather', 'broadcast', 'reduce_dict', 'setup_model_for_distributed', 'create_distributed_dataloader', 'set_random_seed', 'print_distributed_info', 'DistributedMetricAggregator', 'save_on_master', 'log_on_master' ]