bge_finetune/utils/__init__.py

"""
Utility modules for logging, checkpointing, and distributed training
"""
from .logging import (
    setup_logging,
    get_logger,
    MetricsLogger,
    TensorBoardLogger,
    ProgressLogger,
    log_model_info,
    log_training_summary
)
from .checkpoint import CheckpointManager, save_model_for_inference
from .distributed import (
    setup_distributed,
    cleanup_distributed,
    is_main_process,
    get_rank,
    get_world_size,
    get_local_rank,
    synchronize,
    all_reduce,
    all_gather,
    broadcast,
    reduce_dict,
    setup_model_for_distributed,
    create_distributed_dataloader,
    set_random_seed,
    print_distributed_info,
    DistributedMetricAggregator,
    save_on_master,
    log_on_master
)
import os
import sys
import logging
logger = logging.getLogger(__name__)


def optional_import(module_name, feature_desc=None, warn_only=True):
    """Try to import an optional module. Log a warning if not found. Returns the module or None."""
    import importlib
    try:
        return importlib.import_module(module_name)
    except ImportError:
        msg = f"Optional dependency '{module_name}' not found."
        if feature_desc:
            msg += f" {feature_desc} will be unavailable."
        if warn_only:
            logger.warning(msg)
        else:
            logger.error(msg)
        return None
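
# Illustrative usage of optional_import (the 'faiss' module below is only a
# stand-in example, not a dependency this package is known to probe for):
#
#     faiss = optional_import('faiss', feature_desc='FAISS-based similarity search')
#     if faiss is not None:
#         index = faiss.IndexFlatIP(768)  # proceed only when the import succeeded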


# Import guard for optional dependencies
def _check_optional_dependencies():
    """Check for optional dependencies and log warnings if missing when features are likely to be used."""
    # Check for torch_npu if running on Ascend (signalled via the DEVICE environment variable)
    torch = optional_import('torch', warn_only=True)
    if torch is not None and os.environ.get('DEVICE', '').lower() == 'npu':
        optional_import('torch_npu', feature_desc='Ascend NPU support')

    # Load the project config once; both the wandb and tensorboard checks read from it
    config = None
    try:
        from utils.config_loader import get_config
        config = get_config()
    except Exception:
        config = None

    # Check for wandb if Weights & Biases logging is enabled
    use_wandb = False
    if config is not None:
        try:
            use_wandb = config.get('logging.use_wandb', False)  # type: ignore[arg-type]
        except Exception:
            pass
    if use_wandb:
        optional_import('wandb', feature_desc='Weights & Biases logging')

    # Check for tensorboard if TensorBoard logging is likely to be used
    tensorboard_dir = None
    if config is not None:
        try:
            tensorboard_dir = config.get('logging.tensorboard_dir', None)  # type: ignore[arg-type]
        except Exception:
            pass
    if tensorboard_dir:
        optional_import('tensorboard', feature_desc='TensorBoard logging')

    # Check for pytest if running tests (sys is already imported at module level)
    if any('pytest' in arg for arg in sys.argv):
        optional_import('pytest', feature_desc='Some tests may not run')
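
# _check_optional_dependencies() is not called at import time. A minimal sketch of
# the intended call site (the entry-point import path below is an assumption, not
# something defined in this module):
#
#     from bge_finetune.utils import _check_optional_dependencies
#     _check_optional_dependencies()  # warn about missing optional deps up front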

__all__ = [
    # Logging
    'setup_logging',
    'get_logger',
    'MetricsLogger',
    'TensorBoardLogger',
    'ProgressLogger',
    'log_model_info',
    'log_training_summary',
    # Checkpointing
    'CheckpointManager',
    'save_model_for_inference',
    # Distributed training
    'setup_distributed',
    'cleanup_distributed',
    'is_main_process',
    'get_rank',
    'get_world_size',
    'get_local_rank',
    'synchronize',
    'all_reduce',
    'all_gather',
    'broadcast',
    'reduce_dict',
    'setup_model_for_distributed',
    'create_distributed_dataloader',
    'set_random_seed',
    'print_distributed_info',
    'DistributedMetricAggregator',
    'save_on_master',
    'log_on_master'
]