init commit

This commit is contained in:
ldy
2025-07-22 16:55:25 +08:00
commit 36003b83e2
67 changed files with 76613 additions and 0 deletions

127
utils/__init__.py Normal file
View File

@@ -0,0 +1,127 @@
"""
Utility modules for logging, checkpointing, and distributed training
"""
from .logging import (
setup_logging,
get_logger,
MetricsLogger,
TensorBoardLogger,
ProgressLogger,
log_model_info,
log_training_summary
)
from .checkpoint import CheckpointManager, save_model_for_inference
from .distributed import (
setup_distributed,
cleanup_distributed,
is_main_process,
get_rank,
get_world_size,
get_local_rank,
synchronize,
all_reduce,
all_gather,
broadcast,
reduce_dict,
setup_model_for_distributed,
create_distributed_dataloader,
set_random_seed,
print_distributed_info,
DistributedMetricAggregator,
save_on_master,
log_on_master
)
import os
import sys
import logging
logger = logging.getLogger(__name__)
def optional_import(module_name, feature_desc=None, warn_only=True):
    """Try to import an optional module. Log a warning if not found. Returns the module or None."""
    import importlib

    # Attempt the import; fall through to the logging path only on failure.
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        pass
    else:
        return module
    msg = f"Optional dependency '{module_name}' not found."
    if feature_desc:
        msg += f" {feature_desc} will be unavailable."
    # warn_only selects the severity; either way the caller just gets None.
    (logger.warning if warn_only else logger.error)(msg)
    return None
# Import guard for optional dependencies
def _check_optional_dependencies():
    """Check for optional dependencies and log warnings if missing when features are likely to be used.

    Best-effort only: every probe is wrapped so a missing/broken optional
    component never prevents the package from importing.
    """
    # Check for torch_npu if running on Ascend.
    # Fix: the previous probe called torch.device() with no arguments, which
    # raises TypeError on any torch install. Detect NPU intent via the DEVICE
    # env var or a torch.npu attribute instead.
    torch = optional_import('torch', warn_only=True)
    wants_npu = os.environ.get('DEVICE', '').lower() == 'npu'
    if torch is not None and (wants_npu or hasattr(torch, 'npu')):
        optional_import('torch_npu', feature_desc='Ascend NPU support')
    # Load the project config, if possible, to decide which logging backends
    # are expected. Fix: previously `config` could be unbound when the import
    # failed, and the tensorboard check silently relied on that.
    config = None
    try:
        from utils.config_loader import get_config
        config = get_config()
    except Exception:
        config = None
    if config is not None:
        # Check for wandb if Weights & Biases logging is enabled
        use_wandb = False
        try:
            use_wandb = config.get('logging.use_wandb', False)  # type: ignore[arg-type]
        except Exception:
            pass
        if use_wandb:
            optional_import('wandb', feature_desc='Weights & Biases logging')
        # Check for tensorboard if tensorboard logging is likely to be used
        tensorboard_dir = None
        try:
            tensorboard_dir = config.get('logging.tensorboard_dir', None)  # type: ignore[arg-type]
        except Exception:
            pass
        if tensorboard_dir:
            optional_import('tensorboard', feature_desc='TensorBoard logging')
    # Check for pytest if running tests (sys is imported at module level)
    if any('pytest' in arg for arg in sys.argv):
        optional_import('pytest', feature_desc='Some tests may not run')
# Public API of the utils package, grouped by the submodule each name comes from.
__all__ = [
    # Logging
    'setup_logging',
    'get_logger',
    'MetricsLogger',
    'TensorBoardLogger',
    'ProgressLogger',
    'log_model_info',
    'log_training_summary',
    # Checkpointing
    'CheckpointManager',
    'save_model_for_inference',
    # Distributed training
    'setup_distributed',
    'cleanup_distributed',
    'is_main_process',
    'get_rank',
    'get_world_size',
    'get_local_rank',
    'synchronize',
    'all_reduce',
    'all_gather',
    'broadcast',
    'reduce_dict',
    'setup_model_for_distributed',
    'create_distributed_dataloader',
    'set_random_seed',
    'print_distributed_info',
    'DistributedMetricAggregator',
    'save_on_master',
    'log_on_master'
]

271
utils/ascend_support.py Normal file
View File

@@ -0,0 +1,271 @@
"""
Huawei Ascend NPU support utilities for BGE fine-tuning
This module provides compatibility layer for running BGE models on Ascend NPUs.
Requires torch_npu to be installed: pip install torch_npu
"""
import os
import logging
from typing import Optional, Dict, Any, Tuple
import torch
logger = logging.getLogger(__name__)
# Global flag to track Ascend availability
_ASCEND_AVAILABLE = False  # cached result of the probe
_ASCEND_CHECK_DONE = False  # ensures the probe runs at most once per process
def check_ascend_availability() -> bool:
    """
    Check if Ascend NPU is available
    Returns:
        bool: True if Ascend NPU is available and properly configured
    """
    global _ASCEND_AVAILABLE, _ASCEND_CHECK_DONE
    # Return the cached answer on every call after the first.
    if _ASCEND_CHECK_DONE:
        return _ASCEND_AVAILABLE
    _ASCEND_CHECK_DONE = True
    try:
        # Importing torch_npu is what exposes torch.npu — presumably it
        # registers the NPU backend on import; TODO confirm for the
        # installed torch_npu version.
        import torch_npu
        # Check if NPU is available
        if hasattr(torch, 'npu') and torch.npu.is_available():
            _ASCEND_AVAILABLE = True
            npu_count = torch.npu.device_count()
            logger.info(f"Ascend NPU available with {npu_count} devices")
            # Log NPU information
            for i in range(npu_count):
                props = torch.npu.get_device_properties(i)
                logger.info(f"NPU {i}: {props.name}, Total Memory: {props.total_memory / 1024**3:.2f} GB")
        else:
            logger.info("Ascend NPU not available")
    except ImportError:
        # torch_npu absent: quietly report unavailable (debug level only).
        logger.debug("torch_npu not installed. Ascend NPU support disabled.")
    except Exception as e:
        # Any other probe failure is logged but treated as "unavailable".
        logger.warning(f"Error checking Ascend availability: {e}")
    return _ASCEND_AVAILABLE
def get_ascend_backend() -> str:
    """
    Get the appropriate backend for Ascend NPU
    Returns:
        str: 'hccl' for Ascend, 'nccl' for CUDA, 'gloo' for CPU
    """
    # Prefer the NPU collective backend; otherwise pick by CUDA presence.
    if check_ascend_availability():
        return 'hccl'
    return 'nccl' if torch.cuda.is_available() else 'gloo'
def setup_ascend_device(local_rank: int = 0) -> torch.device:
    """
    Setup Ascend NPU device
    Args:
        local_rank: Local rank for distributed training
    Returns:
        torch.device: Configured device
    Raises:
        RuntimeError: when no Ascend NPU is available.
    """
    if not check_ascend_availability():
        raise RuntimeError("Ascend NPU not available")
    try:
        # Bind this process to the requested NPU, then build the handle.
        torch.npu.set_device(local_rank)
        dev = torch.device(f'npu:{local_rank}')
        logger.info(f"Using Ascend NPU device: {dev}")
        return dev
    except Exception as err:
        logger.error(f"Failed to setup Ascend device: {err}")
        raise
def convert_model_for_ascend(model: torch.nn.Module) -> torch.nn.Module:
    """
    Convert model for Ascend NPU compatibility
    Args:
        model: PyTorch model
    Returns:
        Model configured for Ascend (returned unchanged when no NPU is
        available or the conversion fails)
    """
    if not check_ascend_availability():
        return model
    try:
        import torch_npu
        # Apply Ascend-specific optimizations
        # NOTE(review): transfer_to_npu is assumed to move/patch the model for
        # NPU execution — confirm against the installed torch_npu version.
        model = torch_npu.contrib.transfer_to_npu(model)
        # Enable Ascend graph mode if available
        if hasattr(torch_npu, 'enable_graph_mode'):
            torch_npu.enable_graph_mode()
            logger.info("Enabled Ascend graph mode")
        return model
    except Exception as e:
        # Best effort: any failure falls back to the unmodified model.
        logger.warning(f"Failed to apply Ascend optimizations: {e}")
        return model
def optimize_for_ascend(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Optimize configuration for Ascend NPU
    Args:
        config: Training configuration
    Returns:
        Optimized configuration; the caller's dict (including its nested
        sections) is left unmodified.
    """
    if not check_ascend_availability():
        return config
    # Create a copy to avoid modifying original.
    # Fix: a plain shallow copy still shared the nested 'distributed' and
    # 'optimization' dicts with the caller, so the writes below leaked into
    # the original config. Re-wrap nested sections before mutating them.
    config = config.copy()
    # Ascend-specific optimizations
    ascend_config = config.get('ascend', {})
    # Set backend to hccl for distributed training
    if ascend_config.get('backend') == 'hccl':
        config['distributed'] = dict(config.get('distributed', {}))
        config['distributed']['backend'] = 'hccl'
    # Enable mixed precision if specified
    if ascend_config.get('mixed_precision', False):
        config['optimization'] = dict(config.get('optimization', {}))
        config['optimization']['fp16'] = True
        logger.info("Enabled mixed precision for Ascend")
    # Set visible devices (process-wide side effect on the environment)
    if 'visible_devices' in ascend_config:
        os.environ['ASCEND_RT_VISIBLE_DEVICES'] = ascend_config['visible_devices']
        logger.info(f"Set ASCEND_RT_VISIBLE_DEVICES={ascend_config['visible_devices']}")
    return config
def get_ascend_memory_info(device_id: int = 0) -> Tuple[float, float]:
    """
    Get Ascend NPU memory information
    Args:
        device_id: NPU device ID
    Returns:
        Tuple of (used_memory_gb, total_memory_gb); (0.0, 0.0) when no NPU is
        available or the query fails
    """
    if not check_ascend_availability():
        return 0.0, 0.0
    try:
        import torch_npu
        # Get memory info
        # NOTE(review): other code in this module queries devices via
        # torch.npu.*; whether memory_allocated / get_device_properties live on
        # torch_npu directly depends on the torch_npu version — confirm.
        # Failures are caught below and reported as (0.0, 0.0).
        used = torch_npu.memory_allocated(device_id) / 1024**3
        total = torch_npu.get_device_properties(device_id).total_memory / 1024**3
        return used, total
    except Exception as e:
        logger.warning(f"Failed to get Ascend memory info: {e}")
        return 0.0, 0.0
def clear_ascend_cache():
    """Clear Ascend NPU memory cache (no-op when no NPU is available)."""
    if not check_ascend_availability():
        return
    try:
        torch.npu.empty_cache()
    except Exception as exc:
        logger.warning(f"Failed to clear Ascend cache: {exc}")
    else:
        logger.debug("Cleared Ascend NPU cache")
def is_ascend_available() -> bool:
    """
    Simple check for Ascend availability
    Returns:
        bool: True if Ascend NPU is available
    """
    # Thin public alias over the cached availability probe.
    return bool(check_ascend_availability())
class AscendDataLoader:
    """
    Wrapper for DataLoader with Ascend-specific optimizations

    Iterates the wrapped dataloader and moves every tensor in each batch to
    the target device; non-tensor values pass through untouched. List/tuple
    batches are yielded back as lists.
    """
    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def _to_device(self, value):
        # Move a single value if it is a tensor; leave anything else alone.
        return value.to(self.device) if isinstance(value, torch.Tensor) else value

    def __iter__(self):
        for batch in self.dataloader:
            if isinstance(batch, dict):
                yield {key: self._to_device(val) for key, val in batch.items()}
            elif isinstance(batch, (list, tuple)):
                yield [self._to_device(item) for item in batch]
            else:
                # Assume a bare tensor (or tensor-like with .to)
                yield batch.to(self.device)

    def __len__(self):
        return len(self.dataloader)
# Compatibility functions
def npu_synchronize():
    """Synchronize NPU device (compatibility wrapper)"""
    if not check_ascend_availability():
        # No NPU: synchronize CUDA when present, otherwise do nothing.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return
    try:
        torch.npu.synchronize()
    except Exception as exc:
        logger.warning(f"NPU synchronize failed: {exc}")
def get_device_name(device: torch.device) -> str:
    """
    Get device name with Ascend support
    Args:
        device: PyTorch device
    Returns:
        str: Device name (e.g. "Ascend <name>", the CUDA device name, or
        str(device) for CPU and other device types)
    """
    if device.type == 'npu':
        if check_ascend_availability():
            try:
                props = torch.npu.get_device_properties(device.index or 0)
                return f"Ascend {props.name}"
            # Fix: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; only real errors should degrade the name.
            except Exception:
                return "Ascend NPU"
        return "NPU (unavailable)"
    elif device.type == 'cuda':
        return torch.cuda.get_device_name(device.index or 0)
    else:
        return str(device)

465
utils/checkpoint.py Normal file
View File

@@ -0,0 +1,465 @@
import os
import torch
import json
import shutil
from typing import Dict, Any, Optional, List
import logging
from datetime import datetime
import glob
logger = logging.getLogger(__name__)
class CheckpointManager:
    """
    Manages model checkpoints with automatic cleanup and best model tracking

    Checkpoints are written to ``checkpoint-<step>`` subdirectories of
    ``output_dir``. Saves are atomic (write to a temp dir, then rename) so an
    interrupted save never corrupts an existing checkpoint. Only the newest
    ``save_total_limit`` step checkpoints are kept; the best model is
    additionally mirrored into ``best_model_dir``.
    """
    def __init__(
        self,
        output_dir: str,
        save_total_limit: int = 3,
        best_model_dir: str = "best_model",
        save_optimizer: bool = True,
        save_scheduler: bool = True
    ):
        """
        Args:
            output_dir: Root directory that receives checkpoint-<step> dirs.
            save_total_limit: Maximum number of step checkpoints to retain.
            best_model_dir: Name of the best-model subdirectory under output_dir.
            save_optimizer: Persist optimizer state alongside the model.
            save_scheduler: Persist LR scheduler state alongside the model.
        """
        self.output_dir = output_dir
        self.save_total_limit = save_total_limit
        self.best_model_dir = os.path.join(output_dir, best_model_dir)
        self.save_optimizer = save_optimizer
        self.save_scheduler = save_scheduler
        # Create directories
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(self.best_model_dir, exist_ok=True)
        # Track checkpoints as (step, path) tuples, sorted ascending by step
        self.checkpoints = []
        self._load_existing_checkpoints()

    def _load_existing_checkpoints(self):
        """Scan output_dir for existing checkpoint-<step> dirs, sorted by step."""
        checkpoint_dirs = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))
        self.checkpoints = []
        for cp_dir in checkpoint_dirs:
            # Fix: was a bare `except: pass`; catch only the int() parse
            # failure so unrelated errors are not silently swallowed.
            # Non-numeric suffixes (e.g. "checkpoint-5-tmp") are skipped.
            try:
                step = int(cp_dir.split("-")[-1])
            except ValueError:
                continue
            self.checkpoints.append((step, cp_dir))
        self.checkpoints.sort(key=lambda x: x[0])

    def save(
        self,
        state_dict: Dict[str, Any],
        step: int,
        is_best: bool = False,
        metrics: Optional[Dict[str, float]] = None
    ) -> str:
        """
        Save checkpoint with atomic writes and error handling
        Args:
            state_dict: State dictionary containing model, optimizer, etc.
            step: Current training step
            is_best: Whether this is the best model so far
            metrics: Optional metrics to save with checkpoint
        Returns:
            Path to saved checkpoint
        Raises:
            RuntimeError: If the save fails (any pre-existing checkpoint for
                this step is restored from its backup).
        """
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        temp_dir = os.path.join(self.output_dir, f"checkpoint-{step}-tmp")
        try:
            # Write everything into a temp dir first so a crash mid-save
            # leaves any existing checkpoint-{step} untouched.
            os.makedirs(temp_dir, exist_ok=True)
            checkpoint_files = {}
            # Save model
            if 'model_state_dict' in state_dict:
                model_path = os.path.join(temp_dir, 'pytorch_model.bin')
                self._safe_torch_save(state_dict['model_state_dict'], model_path)
                checkpoint_files['model'] = 'pytorch_model.bin'
            # Save optimizer
            if self.save_optimizer and 'optimizer_state_dict' in state_dict:
                optimizer_path = os.path.join(temp_dir, 'optimizer.pt')
                self._safe_torch_save(state_dict['optimizer_state_dict'], optimizer_path)
                checkpoint_files['optimizer'] = 'optimizer.pt'
            # Save scheduler
            if self.save_scheduler and 'scheduler_state_dict' in state_dict:
                scheduler_path = os.path.join(temp_dir, 'scheduler.pt')
                self._safe_torch_save(state_dict['scheduler_state_dict'], scheduler_path)
                checkpoint_files['scheduler'] = 'scheduler.pt'
            # Save training state (bookkeeping used on resume)
            training_state = {
                'step': step,
                'checkpoint_files': checkpoint_files,
                'timestamp': datetime.now().isoformat()
            }
            # Add additional state info
            for key in ['global_step', 'current_epoch', 'best_metric']:
                if key in state_dict:
                    training_state[key] = state_dict[key]
            if metrics:
                training_state['metrics'] = metrics
            # Save config if provided
            if 'config' in state_dict:
                config_path = os.path.join(temp_dir, 'config.json')
                self._safe_json_save(state_dict['config'], config_path)
            # Save training state
            state_path = os.path.join(temp_dir, 'training_state.json')
            self._safe_json_save(training_state, state_path)
            # Atomic swap: back up any existing checkpoint, move temp into
            # place, then drop the backup.
            if os.path.exists(checkpoint_dir):
                backup_dir = f"{checkpoint_dir}.backup"
                if os.path.exists(backup_dir):
                    shutil.rmtree(backup_dir)
                shutil.move(checkpoint_dir, backup_dir)
            shutil.move(temp_dir, checkpoint_dir)
            if os.path.exists(f"{checkpoint_dir}.backup"):
                shutil.rmtree(f"{checkpoint_dir}.backup")
            # Add to checkpoint list
            self.checkpoints.append((step, checkpoint_dir))
            # Save best model
            if is_best:
                self._save_best_model(checkpoint_dir, step, metrics)
            # Cleanup old checkpoints
            self._cleanup_checkpoints()
            logger.info(f"Saved checkpoint to {checkpoint_dir}")
            return checkpoint_dir
        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            # Cleanup temp directory
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            # Restore backup if exists
            backup_dir = f"{checkpoint_dir}.backup"
            if os.path.exists(backup_dir):
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                shutil.move(backup_dir, checkpoint_dir)
            raise RuntimeError(f"Checkpoint save failed: {e}")

    def _save_best_model(
        self,
        checkpoint_dir: str,
        step: int,
        metrics: Optional[Dict[str, float]] = None
    ):
        """Mirror the files of *checkpoint_dir* into the best-model directory."""
        # Copy checkpoint files (top-level only) to the best model directory
        for filename in os.listdir(checkpoint_dir):
            src = os.path.join(checkpoint_dir, filename)
            dst = os.path.join(self.best_model_dir, filename)
            if os.path.isfile(src):
                shutil.copy2(src, dst)
        # Record where the best model came from and when
        best_info = {
            'step': step,
            'source_checkpoint': checkpoint_dir,
            'timestamp': datetime.now().isoformat()
        }
        if metrics:
            best_info['metrics'] = metrics
        best_info_path = os.path.join(self.best_model_dir, 'best_model_info.json')
        with open(best_info_path, 'w') as f:
            json.dump(best_info, f, indent=2)
        logger.info(f"Saved best model from step {step}")

    def _cleanup_checkpoints(self):
        """Remove old checkpoints to maintain save_total_limit"""
        if self.save_total_limit is not None and len(self.checkpoints) > self.save_total_limit:
            # Remove oldest checkpoints
            checkpoints_to_remove = self.checkpoints[:-self.save_total_limit]
            for step, checkpoint_dir in checkpoints_to_remove:
                if os.path.exists(checkpoint_dir):
                    shutil.rmtree(checkpoint_dir)
                    logger.info(f"Removed old checkpoint: {checkpoint_dir}")
            # Update checkpoint list
            self.checkpoints = self.checkpoints[-self.save_total_limit:]

    def _load_tensor_file(self, path: str, map_location: str, what: str, required: bool):
        """torch.load with a safe weights_only pass first, legacy fallback second.

        Args:
            path: File to load.
            map_location: Device to map tensors to.
            what: Human-readable description used in log/error messages.
            required: When True, raise RuntimeError on total failure; when
                False, log a warning and return None.
        """
        try:
            # Security: weights_only prevents arbitrary code execution from
            # pickled payloads.
            return torch.load(path, map_location=map_location, weights_only=True)
        except Exception as e:
            # Legacy fallback for checkpoints saved with non-tensor objects
            try:
                obj = torch.load(path, map_location=map_location)
                logger.warning(f"Loaded {what} using legacy mode (weights_only=False)")
                return obj
            except Exception as legacy_e:
                if required:
                    raise RuntimeError(f"Failed to load {what}: {e}, Legacy attempt: {legacy_e}")
                # Fix: previously the legacy error was computed but dropped
                # from the warning — include both for diagnosis.
                logger.warning(f"Failed to load {what}: {e} (legacy attempt: {legacy_e})")
                return None

    def load_checkpoint(
        self,
        checkpoint_path: Optional[str] = None,
        load_best: bool = False,
        map_location: str = 'cpu',
        strict: bool = True
    ) -> Dict[str, Any]:
        """
        Load checkpoint with robust error handling
        Args:
            checkpoint_path: Path to specific checkpoint
            load_best: Whether to load best model
            map_location: Device to map tensors to
            strict: Whether to strictly enforce state dict matching (accepted
                for interface compatibility; not consulted here)
        Returns:
            Loaded state dictionary
        Raises:
            RuntimeError: If checkpoint loading fails
        """
        try:
            if load_best:
                checkpoint_path = self.best_model_dir
            elif checkpoint_path is None:
                # Load latest checkpoint
                if self.checkpoints:
                    checkpoint_path = self.checkpoints[-1][1]
                else:
                    raise ValueError("No checkpoints found")
            # Ensure checkpoint_path is a usable string
            if checkpoint_path is None:
                raise ValueError("Checkpoint path is None")
            checkpoint_path = str(checkpoint_path)
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
            logger.info(f"Loading checkpoint from {checkpoint_path}")
            state_dict = {}
            # Load training state (step, metrics, ...) if present
            state_path = os.path.join(checkpoint_path, 'training_state.json')
            if os.path.exists(state_path):
                try:
                    with open(state_path, 'r') as f:
                        state_dict.update(json.load(f))
                except Exception as e:
                    logger.warning(f"Failed to load training state: {e}")
            # Load model state (required when the file exists)
            model_path = os.path.join(checkpoint_path, 'pytorch_model.bin')
            if os.path.exists(model_path):
                state_dict['model_state_dict'] = self._load_tensor_file(
                    model_path, map_location, what='model state', required=True
                )
            else:
                logger.warning(f"Model state not found at {model_path}")
            # Load optimizer state (optional; failure is non-fatal)
            optimizer_path = os.path.join(checkpoint_path, 'optimizer.pt')
            if os.path.exists(optimizer_path):
                loaded = self._load_tensor_file(
                    optimizer_path, map_location, what='optimizer state', required=False
                )
                if loaded is not None:
                    state_dict['optimizer_state_dict'] = loaded
            # Load scheduler state (optional; failure is non-fatal)
            scheduler_path = os.path.join(checkpoint_path, 'scheduler.pt')
            if os.path.exists(scheduler_path):
                loaded = self._load_tensor_file(
                    scheduler_path, map_location, what='scheduler state', required=False
                )
                if loaded is not None:
                    state_dict['scheduler_state_dict'] = loaded
            # Load config if available
            config_path = os.path.join(checkpoint_path, 'config.json')
            if os.path.exists(config_path):
                try:
                    with open(config_path, 'r') as f:
                        state_dict['config'] = json.load(f)
                except Exception as e:
                    logger.warning(f"Failed to load config: {e}")
            logger.info(f"Successfully loaded checkpoint from {checkpoint_path}")
            return state_dict
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            raise RuntimeError(f"Checkpoint loading failed: {e}")

    def get_latest_checkpoint(self) -> Optional[str]:
        """Get path to latest checkpoint"""
        if self.checkpoints:
            return self.checkpoints[-1][1]
        return None

    def get_best_checkpoint(self) -> str:
        """Get path to best model checkpoint"""
        return self.best_model_dir

    def checkpoint_exists(self, step: int) -> bool:
        """Check if checkpoint exists for given step"""
        checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-{step}")
        return os.path.exists(checkpoint_dir)

    def _safe_torch_save(self, obj: Any, path: str):
        """Save torch object atomically (write to .tmp, then os.replace)."""
        temp_path = f"{path}.tmp"
        try:
            torch.save(obj, temp_path)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")

    def _safe_json_save(self, obj: Any, path: str):
        """Save JSON atomically (write to .tmp, then os.replace)."""
        temp_path = f"{path}.tmp"
        try:
            with open(temp_path, 'w') as f:
                json.dump(obj, f, indent=2)
            # Atomic rename
            os.replace(temp_path, path)
        except Exception as e:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise RuntimeError(f"Failed to save {path}: {e}")
def save_model_for_inference(
    model: torch.nn.Module,
    tokenizer: Any,
    output_dir: str,
    model_name: str = "model",
    save_config: bool = True,
    additional_files: Optional[Dict[str, str]] = None
):
    """
    Save model in a format ready for inference
    Args:
        model: The model to save
        tokenizer: Tokenizer to save (anything with save_pretrained; others
            are skipped)
        output_dir: Output directory
        model_name: Name for the model
        save_config: Whether to save model config
        additional_files: Additional files to include (dict values are JSON-
            dumped; everything else is written via str())
    """
    os.makedirs(output_dir, exist_ok=True)
    # Save model weights
    model_path = os.path.join(output_dir, f"{model_name}.pt")
    torch.save(model.state_dict(), model_path)
    logger.info(f"Saved model to {model_path}")
    # Save tokenizer
    if hasattr(tokenizer, 'save_pretrained'):
        tokenizer.save_pretrained(output_dir)
        logger.info(f"Saved tokenizer to {output_dir}")
    # Save config if available
    if save_config and hasattr(model, 'config'):
        # Fix: removed the unused `config_path` local — save_pretrained
        # writes config.json into output_dir itself.
        model.config.save_pretrained(output_dir)  # type: ignore[attr-defined]
    # Save additional files
    if additional_files:
        for filename, content in additional_files.items():
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'w') as f:
                if isinstance(content, dict):
                    json.dump(content, f, indent=2)
                else:
                    f.write(str(content))
    # Create README describing the layout and how to reload
    readme_content = f"""# {model_name} Model
This directory contains a fine-tuned model ready for inference.
## Files:
- `{model_name}.pt`: Model weights
- `tokenizer_config.json`: Tokenizer configuration
- `special_tokens_map.json`: Special tokens mapping
- `vocab.txt` or `tokenizer.json`: Vocabulary file
## Loading the model:
```python
import torch
from transformers import AutoTokenizer
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('{output_dir}')
# Load model
model = YourModelClass() # Initialize your model class
model.load_state_dict(torch.load('{model_name}.pt'))
model.eval()
```
## Model Information:
- Model type: {model.__class__.__name__}
- Saved on: {datetime.now().isoformat()}
"""
    readme_path = os.path.join(output_dir, 'README.md')
    with open(readme_path, 'w') as f:
        f.write(readme_content)
    logger.info(f"Model saved for inference in {output_dir}")

518
utils/config_loader.py Normal file
View File

@@ -0,0 +1,518 @@
"""
Centralized configuration loader for BGE fine-tuning project
"""
import os
import toml
import logging
from typing import Dict, Any, Optional, Tuple
from pathlib import Path
logger = logging.getLogger(__name__)
def resolve_model_path(
    model_key: str,
    config_instance: 'ConfigLoader',
    model_name_or_path: Optional[str] = None,
    cache_dir: Optional[str] = None
) -> Tuple[str, str]:
    """
    Resolve model path with automatic local/remote fallback

    Precedence: explicit CLI override > valid local directory from config >
    remote model identifier (with a note to download into the configured
    local path when one was set but does not exist yet).
    Args:
        model_key: Key in model_paths config (e.g., 'bge_m3', 'bge_reranker')
        config_instance: ConfigLoader instance
        model_name_or_path: Optional override from CLI
        cache_dir: Optional cache directory
    Returns:
        Tuple of (resolved_model_path, resolved_cache_dir)
    """
    # Default remote model names for each model type
    remote_models = {
        'bge_m3': 'BAAI/bge-m3',
        'bge_reranker': 'BAAI/bge-reranker-base',
    }
    # Get model source (huggingface or modelscope)
    model_source = config_instance.get('model_paths', 'source', 'huggingface')
    # If ModelScope is used, adjust model names
    # NOTE(review): these ModelScope ids are GTE/CoROM models, not BGE —
    # confirm they are the intended mirrors before relying on this path.
    if model_source == 'modelscope':
        remote_models = {
            'bge_m3': 'damo/nlp_gte_sentence-embedding_chinese-base',
            'bge_reranker': 'damo/nlp_corom_passage-ranking_english-base',
        }
    # 1. Check CLI override first (only if explicitly provided)
    if model_name_or_path is not None:
        resolved_model_path = model_name_or_path
        logger.info(f"[CLI OVERRIDE] Using model path: {resolved_model_path}")
    else:
        # 2. Check local path from config
        local_path_raw = config_instance.get('model_paths', model_key, None)
        # Handle list/string conversion (ConfigLoader.get may split "a,b"
        # strings into lists; take the first entry)
        if isinstance(local_path_raw, list):
            local_path = local_path_raw[0] if local_path_raw else None
        else:
            local_path = local_path_raw
        if local_path and isinstance(local_path, str) and os.path.exists(local_path):
            # Check if it's a valid model directory
            if _is_valid_model_directory(local_path):
                resolved_model_path = local_path
                logger.info(f"[LOCAL] Using local model: {resolved_model_path}")
            else:
                logger.warning(f"[LOCAL] Invalid model directory: {local_path}, falling back to remote")
                resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
        else:
            # 3. Fall back to remote model
            resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
            logger.info(f"[REMOTE] Using remote model: {resolved_model_path}")
            # If local path was configured but doesn't exist, we'll download to that location
            if local_path and isinstance(local_path, str):
                logger.info(f"[DOWNLOAD] Will download to configured path: {local_path}")
                # Create directory if it doesn't exist
                os.makedirs(local_path, exist_ok=True)
    # Resolve cache directory (CLI value wins; otherwise config, with the
    # same list/first-entry handling as above)
    if cache_dir is not None:
        resolved_cache_dir = cache_dir
    else:
        cache_dir_raw = config_instance.get('model_paths', 'cache_dir', './cache/data')
        if isinstance(cache_dir_raw, list):
            resolved_cache_dir = cache_dir_raw[0] if cache_dir_raw else './cache/data'
        else:
            resolved_cache_dir = cache_dir_raw if cache_dir_raw else './cache/data'
    # Ensure cache directory exists
    if isinstance(resolved_cache_dir, str):
        os.makedirs(resolved_cache_dir, exist_ok=True)
    return resolved_model_path, resolved_cache_dir
def _is_valid_model_directory(path: str) -> bool:
"""Check if a directory contains a valid model"""
required_files = ['config.json']
optional_files = ['tokenizer_config.json', 'vocab.txt', 'tokenizer.json']
model_files = ['pytorch_model.bin', 'model.safetensors', 'pytorch_model.bin.index.json']
path_obj = Path(path)
if not path_obj.exists() or not path_obj.is_dir():
return False
# Check for required files
for file in required_files:
if not (path_obj / file).exists():
return False
# Check for at least one model file
has_model_file = any((path_obj / file).exists() for file in model_files)
if not has_model_file:
return False
# Check for at least one tokenizer file (for embedding models)
has_tokenizer_file = any((path_obj / file).exists() for file in optional_files)
if not has_tokenizer_file:
logger.warning(f"Model directory {path} missing tokenizer files, but proceeding anyway")
return True
def download_model_if_needed(
    model_name_or_path: str,
    local_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    model_source: str = 'huggingface'
) -> str:
    """
    Download model if it's a remote identifier and local path is specified
    Args:
        model_name_or_path: Model identifier or path
        local_path: Target local path for download
        cache_dir: Cache directory
        model_source: 'huggingface' or 'modelscope'
    Returns:
        Final model path (local if downloaded, original if already local)
    """
    # If it's already a local path, return as is
    if os.path.exists(model_name_or_path):
        return model_name_or_path
    # If no local path specified, return original (will be handled by transformers)
    if not local_path:
        return model_name_or_path
    # Check if local path already exists and is valid
    if os.path.exists(local_path) and _is_valid_model_directory(local_path):
        logger.info(f"Model already exists locally: {local_path}")
        return local_path
    # Download the model
    try:
        logger.info(f"Downloading model {model_name_or_path} to {local_path}")
        if model_source == 'huggingface':
            from transformers import AutoModel, AutoTokenizer
            # Download model and tokenizer
            # NOTE(review): trust_remote_code=True executes code from the
            # model repository — ensure the configured sources are trusted.
            model = AutoModel.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            # Save to local path
            os.makedirs(local_path, exist_ok=True)
            model.save_pretrained(local_path)
            tokenizer.save_pretrained(local_path)
            logger.info(f"Model downloaded successfully to {local_path}")
            return local_path
        elif model_source == 'modelscope':
            try:
                from modelscope import snapshot_download  # type: ignore
                # Download using ModelScope
                downloaded_path = snapshot_download(
                    model_name_or_path,
                    cache_dir=cache_dir,
                    local_dir=local_path
                )
                logger.info(f"Model downloaded successfully from ModelScope to {downloaded_path}")
                return downloaded_path
            except ImportError:
                # ModelScope missing: retry the whole flow via HuggingFace
                logger.warning("ModelScope not available, falling back to HuggingFace")
                return download_model_if_needed(
                    model_name_or_path, local_path, cache_dir, 'huggingface'
                )
    except Exception as e:
        logger.error(f"Failed to download model {model_name_or_path}: {e}")
        logger.info("Returning original model path - transformers will handle caching")
    # Fallback return: reached on download failure or an unrecognized model_source
    return model_name_or_path
class ConfigLoader:
    """Centralized configuration management for BGE fine-tuning
    Parameter Precedence (highest to lowest):
    1. Command-line arguments (CLI)
    2. Model-specific config ([m3], [reranker])
    3. [hardware] section in config.toml
    4. [training] section in config.toml
    5. Hardcoded defaults in code
    """
    # Singleton instance and its lazily-loaded config (shared class-wide)
    _instance = None
    _config = None
    def __new__(cls):
        # Classic singleton: every ConfigLoader() call returns the same object.
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance
    def __init__(self):
        # Load the TOML config only once; _config is cached on the class.
        if self._config is None:
            self._load_config()
def _find_config_file(self) -> str:
"""Find config.toml file in project hierarchy"""
# Check environment variable FIRST (highest priority)
if 'BGE_CONFIG_PATH' in os.environ:
config_path = Path(os.environ['BGE_CONFIG_PATH'])
if config_path.exists():
return str(config_path)
else:
logger.warning(f"BGE_CONFIG_PATH set to {config_path} but file doesn't exist, falling back to default")
current_dir = Path(__file__).parent
# Search up the directory tree for config.toml
for _ in range(5): # Limit search depth
config_path = current_dir / "config.toml"
if config_path.exists():
return str(config_path)
current_dir = current_dir.parent
# Default location
return os.path.join(os.getcwd(), "config.toml")
def _load_config(self):
    """Load configuration from TOML file

    Falls back to the built-in defaults (and writes them to disk) when no
    config file exists; any parse/IO error also falls back to defaults so
    the process keeps running.
    """
    config_path = self._find_config_file()
    try:
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = toml.load(f)
                if loaded is None:
                    loaded = {}
                self._config = dict(loaded)
                logger.info(f"Loaded configuration from {config_path}")
        else:
            # No config anywhere: materialize the defaults for the user to edit
            logger.warning(f"Config file not found at {config_path}, using defaults")
            self._config = self._get_default_config()
            self._save_default_config(config_path)
            logger.warning("Default config created. Please check your configuration file.")
    except Exception as e:
        logger.error(f"Error loading config: {e}")
        self._config = self._get_default_config()
    # Validation for unset/placeholder values
    self._validate_config()
def _validate_config(self):
"""Validate config for unset/placeholder values"""
api_key = self.get('api', 'api_key', None)
if not api_key or api_key == 'your-api-key-here':
logger.warning("API key is unset or placeholder. Set it in config.toml or via the BGE_API_KEY environment variable.")
def _get_default_config(self) -> Dict[str, Any]:
    """Get default configuration

    Returns the full built-in default tree, used when config.toml is absent
    or unreadable. Section names mirror the TOML tables.
    """
    return {
        # LLM API used for data augmentation / generation
        'api': {
            'base_url': 'https://api.deepseek.com/v1',
            'api_key': 'your-api-key-here',
            'model': 'deepseek-chat',
            'max_retries': 3,
            'timeout': 30,
            'temperature': 0.7
        },
        # Core training hyperparameters
        'training': {
            'default_batch_size': 4,
            'default_learning_rate': 1e-5,
            'default_num_epochs': 3,
            'default_warmup_ratio': 0.1,
            'default_weight_decay': 0.01,
            'gradient_accumulation_steps': 4
        },
        # Local model directories and download cache
        'model_paths': {
            'bge_m3': './models/bge-m3',
            'bge_reranker': './models/bge-reranker-base',
            'cache_dir': './cache/data'
        },
        # Tokenization limits and dataset splitting
        'data': {
            'max_query_length': 64,
            'max_passage_length': 512,
            'max_seq_length': 512,
            'validation_split': 0.1,
            'random_seed': 42
        },
        # Optional LLM-based data augmentation
        'augmentation': {
            'enable_query_augmentation': False,
            'enable_passage_augmentation': False,
            'augmentation_methods': ['paraphrase', 'back_translation'],
            'num_augmentations_per_sample': 2,
            'augmentation_temperature': 0.8
        },
        # Hard-negative mining over a FAISS-style index
        'hard_negative_mining': {
            'enable_mining': True,
            'num_hard_negatives': 15,
            'negative_range_min': 10,
            'negative_range_max': 100,
            'margin': 0.1,
            'index_refresh_interval': 1000,
            'index_type': 'IVF',
            'nlist': 100
        },
        # Retrieval evaluation settings
        'evaluation': {
            'eval_batch_size': 32,
            'k_values': [1, 3, 5, 10, 20, 50, 100],
            'metrics': ['recall', 'precision', 'map', 'mrr', 'ndcg'],
            'save_predictions': True,
            'predictions_dir': './output/predictions'
        },
        # Console/file/TensorBoard/W&B logging
        'logging': {
            'log_level': 'INFO',
            'log_to_file': True,
            'log_dir': './logs',
            'tensorboard_dir': './logs/tensorboard',
            'wandb_project': 'bge-finetuning',
            'wandb_entity': '',
            'use_wandb': False
        },
        # torch.distributed defaults
        'distributed': {
            'backend': 'nccl',
            'init_method': 'env://',
            'find_unused_parameters': True
        },
        # Mixed precision / optimizer numerics
        'optimization': {
            'gradient_checkpointing': True,
            'fp16': True,
            'bf16': False,
            'fp16_opt_level': 'O1',
            'max_grad_norm': 1.0,
            'adam_epsilon': 1e-8,
            'adam_beta1': 0.9,
            'adam_beta2': 0.999
        },
        # Device visibility and dataloader tuning
        'hardware': {
            'cuda_visible_devices': '0,1,2,3',
            'dataloader_num_workers': 4,
            'dataloader_pin_memory': True,
            'per_device_train_batch_size': 4,
            'per_device_eval_batch_size': 8
        },
        # Checkpointing cadence and best-model selection
        'checkpoint': {
            'save_strategy': 'steps',
            'save_steps': 1000,
            'save_total_limit': 3,
            'load_best_model_at_end': True,
            'metric_for_best_model': 'eval_recall@10',
            'greater_is_better': True,
            'resume_from_checkpoint': ''
        },
        # Inference export options
        'export': {
            'export_format': ['pytorch', 'onnx'],
            'optimize_for_inference': True,
            'quantize': False,
            'quantization_bits': 8
        }
    }
def _save_default_config(self, config_path: str):
    """Write the current (default) configuration to *config_path* as TOML.

    Failures are logged rather than raised so an unwritable location
    never aborts startup.

    Args:
        config_path: Destination file path for the TOML config.
    """
    try:
        # os.makedirs('') raises FileNotFoundError, so only create the
        # directory when the path actually contains one (a bare filename
        # like 'config.toml' has an empty dirname).
        directory = os.path.dirname(config_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(config_path, 'w', encoding='utf-8') as f:
            config_to_save = self._config if isinstance(self._config, dict) else {}
            toml.dump(config_to_save, f)
        logger.info(f"Created default config at {config_path}")
    except Exception as e:
        logger.error(f"Error saving default config: {e}")
def get(self, section: str, key: str, default=None):
    """Return the config value for ``section.key``, logging its source.

    String values read from the config file that look like delimited
    lists (contain ',' or ';') are split into a list of stripped,
    non-empty strings.

    Args:
        section: Top-level section name in the config.
        key: Key inside the section.
        default: Value returned when the key is absent.

    Returns:
        The configured value, the parsed list, or ``default``.
    """
    # CLI overrides are handled outside this loader; here we check config and default
    value = default
    source = 'default'
    if isinstance(self._config, dict) and section in self._config and key in self._config[section]:
        value = self._config[section][key]
        source = 'config.toml'
    # Robust list/tuple conversion
    if isinstance(value, str) and (',' in value or ';' in value):
        try:
            value = [v.strip() for v in value.replace(';', ',').split(',') if v.strip()]
            source += ' (parsed as list)'
        except Exception as e:
            # Use the module logger (not the root logger) for consistency
            # with the rest of this module.
            logger.warning(f"Failed to parse list for {section}.{key}: {e}")
    log_func = logger.info if source != 'default' else logger.warning
    log_func(f"Config param {section}.{key} = {value} (source: {source})")
    return value
def get_section(self, section: str) -> Dict[str, Any]:
    """Return the full mapping for *section*, or ``{}`` when the config
    is not loaded or the section is absent."""
    cfg = self._config
    if not isinstance(cfg, dict):
        return {}
    return cfg.get(section, {})
def update(self, key: str, value: Any):
    """Set a (possibly dotted) configuration key to *value*.

    Intermediate dictionaries are created on demand.  A warning is
    emitted when an existing, different value is overwritten.
    """
    if self._config is None:
        self._config = {}
    *parents, leaf = key.split('.')
    node = self._config
    # Walk (creating as needed) the chain of parent dicts down to the
    # container that holds the final key.
    for part in parents:
        node = node.setdefault(part, {})
    previous = node.get(leaf, None)
    if previous is not None and previous != value:
        logger.warning(f"Parameter '{key}' overridden: {previous} -> {value} (higher-precedence source)")
    node[leaf] = value
def reload(self):
    """Discard the cached configuration and re-read it from disk."""
    # Clearing first guarantees _load_config starts from a clean slate.
    self._config = None
    self._load_config()
# Module-level ConfigLoader singleton; prefer get_config() over importing
# this name directly so the shared instance can be reloaded centrally.
config = ConfigLoader()
def get_config() -> ConfigLoader:
    """Return the process-wide :class:`ConfigLoader` singleton."""
    return config
def load_config_for_training(
    config_path: Optional[str] = None,
    override_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Load configuration specifically for training scripts.

    Args:
        config_path: Optional path to config file
        override_params: Parameters to override from command line

    Returns:
        Dictionary with all configuration parameters
    """
    global config
    if config_path:
        # Temporarily point the loader at the requested file, reload,
        # then restore the previous environment state.
        previous = os.environ.get('BGE_CONFIG_PATH')
        os.environ['BGE_CONFIG_PATH'] = config_path
        config.reload()
        if previous:
            os.environ['BGE_CONFIG_PATH'] = previous
        else:
            del os.environ['BGE_CONFIG_PATH']
    # Collect every known section into one flat dict.
    section_names = (
        'api', 'training', 'model_paths', 'data', 'augmentation',
        'hard_negative_mining', 'evaluation', 'logging', 'distributed',
        'optimization', 'hardware', 'checkpoint', 'export', 'm3', 'reranker',
    )
    training_config = {name: config.get_section(name) for name in section_names}
    if override_params:
        for dotted_key, value in override_params.items():
            config.update(dotted_key, value)
            # Mirror the override into the dict we are about to return.
            *parents, leaf = dotted_key.split('.')
            node = training_config
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
    return training_config

425
utils/distributed.py Normal file
View File

@@ -0,0 +1,425 @@
import os
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import logging
from typing import Optional, Any, List
import random
import numpy as np
from .config_loader import get_config
# Import Ascend support utilities
try:
from .ascend_support import (
check_ascend_availability,
get_ascend_backend,
setup_ascend_device,
clear_ascend_cache
)
ASCEND_SUPPORT_AVAILABLE = True
except ImportError:
ASCEND_SUPPORT_AVAILABLE = False
logger = logging.getLogger(__name__)
def setup_distributed(config=None):
    """Setup distributed training environment with proper error handling.

    Reads WORLD_SIZE / RANK / LOCAL_RANK from the environment (as set by
    torchrun and similar launchers), selects a backend (HCCL for Ascend
    NPUs, NCCL for CUDA, Gloo otherwise) and initializes the process
    group.  Any failure is downgraded to single-process mode by
    rewriting the environment variables instead of raising.

    Args:
        config: Optional configuration object consulted for a backend
            override.  NOTE(review): it is called as
            ``config.get('distributed.backend', backend)``, but this
            project's ConfigLoader.get() takes ``(section, key, default)``
            — with a ConfigLoader that call would be read as
            section='distributed.backend', key=<backend>.  Confirm which
            config type callers actually pass here.
    """
    try:
        # Check if we're in a distributed environment
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        rank = int(os.environ.get('RANK', 0))
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if world_size == 1:
            logger.info("Single GPU/CPU training mode")
            return
        # Determine backend based on hardware availability
        # First check for Ascend NPU support
        if ASCEND_SUPPORT_AVAILABLE and check_ascend_availability():
            backend = 'hccl'
            try:
                setup_ascend_device(local_rank)
                logger.info("Using Ascend NPU with HCCL backend")
            except Exception as e:
                # NOTE(review): this falls back to gloo even when CUDA is
                # available — confirm that degradation is intended.
                logger.warning(f"Failed to setup Ascend device: {e}, falling back")
                backend = 'gloo'
        elif torch.cuda.is_available():
            backend = 'nccl'
            # Set CUDA device
            torch.cuda.set_device(local_rank)
        else:
            backend = 'gloo'
        # Override backend from config if provided
        if config is not None:
            backend = config.get('distributed.backend', backend)
        # Initialize process group
        # When MASTER_ADDR is present, build an explicit tcp:// URL;
        # otherwise fall back to the env:// rendezvous.
        init_method = os.environ.get('MASTER_ADDR', None)
        if init_method:
            init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ.get('MASTER_PORT', '29500')}"
        else:
            init_method = 'env://'
        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank
        )
        logger.info(f"Distributed training initialized: backend={backend}, world_size={world_size}, rank={rank}, local_rank={local_rank}")
    except Exception as e:
        logger.warning(f"Failed to initialize distributed training: {e}")
        logger.info("Falling back to single GPU/CPU training")
        # Set environment variables to indicate single GPU mode
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'
        os.environ['LOCAL_RANK'] = '0'
def cleanup_distributed():
    """Tear down the process group, logging (never raising) any failure."""
    if not dist.is_initialized():
        return
    try:
        dist.destroy_process_group()
    except Exception as exc:
        # Cleanup runs at shutdown; failures must not propagate.
        logger.warning(f"Error during distributed cleanup: {exc}")
    else:
        logger.info("Distributed process group destroyed.")
def is_main_process() -> bool:
    """True when running single-process or on global rank 0."""
    return (not dist.is_initialized()) or dist.get_rank() == 0
def get_rank() -> int:
    """Global rank of this process (0 when not distributed)."""
    return dist.get_rank() if dist.is_initialized() else 0
def get_world_size() -> int:
    """Total number of processes (1 when not distributed)."""
    return dist.get_world_size() if dist.is_initialized() else 1
def get_local_rank() -> int:
    """Local (per-node) rank read from LOCAL_RANK; 0 when not distributed."""
    if not dist.is_initialized():
        return 0
    # LOCAL_RANK is set by the launcher (torchrun etc.).
    return int(os.environ.get('LOCAL_RANK', 0))
def synchronize():
    """Barrier across all processes; a no-op when not distributed."""
    if not dist.is_initialized():
        return
    dist.barrier()
def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """All-reduce *tensor* across processes and return it.

    Args:
        tensor: Tensor to reduce (modified in place when distributed).
        op: Reduction operation (sum by default).

    Returns:
        The same tensor object; unchanged when not distributed.
    """
    if dist.is_initialized():
        dist.all_reduce(tensor, op=op)
    return tensor
def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """Gather *tensor* from every process.

    Args:
        tensor: Tensor contributed by this rank.

    Returns:
        A list with one tensor per rank; just ``[tensor]`` when not
        distributed.
    """
    if not dist.is_initialized():
        return [tensor]
    gathered = [torch.empty_like(tensor) for _ in range(get_world_size())]
    dist.all_gather(gathered, tensor)
    return gathered
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast *tensor* from rank *src* to every process.

    Args:
        tensor: Tensor to broadcast (received in place on other ranks).
        src: Source rank.

    Returns:
        The same tensor object; unchanged when not distributed.
    """
    if dist.is_initialized():
        dist.broadcast(tensor, src=src)
    return tensor
def reduce_dict(input_dict: dict, average: bool = True) -> dict:
    """Sum (or average) every value of *input_dict* across processes.

    Keys are processed in sorted order so each rank stacks the values
    identically before the collective call.

    Args:
        input_dict: Mapping of name -> scalar tensor.
        average: Divide by world size after summing.

    Returns:
        Mapping of name -> Python float; the input unchanged when not
        distributed or running with a single process.
    """
    if not dist.is_initialized():
        return input_dict
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        keys = sorted(input_dict.keys())
        stacked = torch.stack([input_dict[k] for k in keys], dim=0)
        dist.all_reduce(stacked)
        if average:
            stacked /= world_size
        return {k: v.item() for k, v in zip(keys, stacked)}
def setup_model_for_distributed(
    model: torch.nn.Module,
    device_ids: Optional[List[int]] = None,
    output_device: Optional[int] = None,
    find_unused_parameters: bool = True,
    broadcast_buffers: bool = True
) -> torch.nn.Module:
    """Wrap *model* in DistributedDataParallel when distributed training
    is active; otherwise return it untouched.

    Args:
        model: Model to wrap.
        device_ids: Devices for this process (defaults to the local rank).
        output_device: Where outputs are gathered (defaults to the first
            entry of ``device_ids``).
        find_unused_parameters: Forwarded to DDP.
        broadcast_buffers: Forwarded to DDP.

    Returns:
        The (possibly DDP-wrapped) model.
    """
    if not dist.is_initialized():
        return model
    devices = device_ids if device_ids is not None else [get_local_rank()]
    out_device = output_device if output_device is not None else devices[0]
    wrapped = DDP(
        model,
        device_ids=devices,
        output_device=out_device,
        find_unused_parameters=find_unused_parameters,
        broadcast_buffers=broadcast_buffers,
    )
    logger.info(f"Wrapped model for distributed training on devices {devices}")
    return wrapped
def create_distributed_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = False,
    seed: int = 42
) -> DataLoader:
    """Build a DataLoader that uses a DistributedSampler when distributed
    training is active, and a plain (optionally shuffled) loader otherwise.

    Args:
        dataset: Dataset to load.
        batch_size: Batch size per process.
        shuffle: Whether to shuffle each epoch.
        num_workers: Number of data-loading workers.
        pin_memory: Whether to pin host memory.
        drop_last: Whether to drop the final incomplete batch.
        seed: Seed used by the distributed sampler's shuffling.

    Returns:
        The configured DataLoader.
    """
    sampler = None
    if dist.is_initialized():
        sampler = DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        # The sampler owns shuffling; the loader itself must not shuffle.
        shuffle = False
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
def set_random_seed(seed: int, deterministic: bool = False):
    """Seed Python, NumPy and PyTorch RNGs for reproducibility.

    Args:
        seed: Seed applied to every RNG.
        deterministic: When True (and CUDA is present), force
            deterministic cuDNN kernels and disable autotuning.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    if deterministic and torch.cuda.is_available():
        # Trades speed for bitwise-reproducible convolutions.
        cudnn.deterministic = True
        cudnn.benchmark = False
    logger.info(f"Set random seed to {seed} (deterministic={deterministic})")
def print_distributed_info():
    """Log backend, world size, ranks and rendezvous address for this
    process; notes when distributed training is not initialized."""
    if not dist.is_initialized():
        logger.info("Distributed training not initialized")
        return
    logger.info(f"Distributed Info:")
    details = (
        f"  Backend: {dist.get_backend()}",
        f"  World size: {get_world_size()}",
        f"  Rank: {get_rank()}",
        f"  Local rank: {get_local_rank()}",
        f"  Master addr: {os.environ.get('MASTER_ADDR', 'N/A')}",
        f"  Master port: {os.environ.get('MASTER_PORT', 'N/A')}",
    )
    for line in details:
        logger.info(line)
class DistributedMetricAggregator:
    """Accumulates scalar metrics and reduces them across processes."""

    def __init__(self):
        # Maps metric name -> tensor holding the most recent value.
        self.metrics = {}

    def update(self, **kwargs):
        """Store metric values, converting plain scalars to float32 tensors."""
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                tensor = value.detach()
            else:
                tensor = torch.tensor(value, dtype=torch.float32)
            if torch.cuda.is_available():
                tensor = tensor.cuda()
            self.metrics[name] = tensor

    def aggregate(self, average: bool = True) -> dict:
        """All-reduce every stored metric and return plain Python numbers."""
        if not dist.is_initialized():
            return {name: t.item() if isinstance(t, torch.Tensor) else t
                    for name, t in self.metrics.items()}
        world_size = get_world_size()
        result = {}
        for name, tensor in self.metrics.items():
            # Sum across ranks (in place), optionally averaging.
            all_reduce(tensor)
            if average:
                tensor = tensor / world_size
            result[name] = tensor.item()
        return result

    def reset(self):
        """Discard all stored metrics."""
        self.metrics = {}
def save_on_master(*args, **kwargs):
    """Decorator that runs the wrapped function only on the master process.

    The original implementation only worked as ``@save_on_master()``;
    used bare (``@save_on_master``) it returned the inner factory and
    mangled the decorated function's calls.  This version supports both
    spellings while staying backward compatible.

    Returns:
        The wrapped function (bare usage) or a decorator (factory usage);
        the wrapper returns None on non-master processes.
    """
    def decorator(func):
        def wrapper(*f_args, **f_kwargs):
            if is_main_process():
                return func(*f_args, **f_kwargs)
            # Skip the (typically I/O) work on non-master ranks.
            return None
        return wrapper
    # Bare usage: @save_on_master  ->  invoked with the function itself.
    if len(args) == 1 and not kwargs and callable(args[0]):
        return decorator(args[0])
    # Factory usage: @save_on_master()  (any extra args are ignored, as before).
    return decorator
def log_on_master(logger: logging.Logger):
    """Wrap *logger* so logging calls become no-ops off the master rank."""
    class MasterOnlyLogger:
        """Proxy forwarding attribute access only on the main process."""
        def __init__(self, logger):
            self.logger = logger

        def __getattr__(self, name):
            # Non-master ranks receive a do-nothing callable instead.
            if not is_main_process():
                return lambda *args, **kwargs: None
            return getattr(self.logger, name)
    return MasterOnlyLogger(logger)

344
utils/logging.py Normal file
View File

@@ -0,0 +1,344 @@
import logging
import os
import sys
from datetime import datetime
from typing import Optional, Dict, Any
import json
import torch
from tqdm import tqdm
import numpy as np
def create_progress_bar(
    iterable,
    desc: str = "",
    total: Optional[int] = None,
    disable: bool = False,
    position: int = 0,
    leave: bool = True,
    ncols: Optional[int] = None,
    unit: str = "it",
    dynamic_ncols: bool = True,
):
    """
    Create a tqdm progress bar with consistent settings across the project.
    Args:
        iterable: Iterable to wrap with progress bar
        desc: Description for the progress bar
        total: Total number of iterations (if None, will try to get from iterable)
        disable: Whether to disable the progress bar
        position: Position of the progress bar (for multiple bars)
        leave: Whether to leave the progress bar after completion
        ncols: Width of the progress bar
        unit: Unit of iteration (e.g., "batch", "step")
        dynamic_ncols: Automatically adjust width based on terminal size
    Returns:
        tqdm progress bar object
    """
    # Disable the bar on non-main ranks so distributed runs print one bar.
    try:
        import torch.distributed as dist
        if dist.is_initialized() and dist.get_rank() != 0:
            disable = True
    except Exception:
        # A bare 'except:' here would also trap KeyboardInterrupt and
        # SystemExit; torch may simply be unavailable, which is fine.
        pass
    # Create progress bar with consistent settings
    return tqdm(
        iterable,
        desc=desc,
        total=total,
        disable=disable,
        position=position,
        leave=leave,
        ncols=ncols,
        unit=unit,
        dynamic_ncols=dynamic_ncols,
        # Use custom write method to avoid conflicts with logging
        file=sys.stdout,
        # Ensure smooth updates
        smoothing=0.1,
        # Format with consistent style
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
    )
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that writes through tqdm.write so active progress
    bars are not corrupted by interleaved log lines."""

    def emit(self, record):
        """Format *record* and emit it via tqdm's bar-safe writer."""
        try:
            tqdm.write(self.format(record), file=sys.stdout)
            self.flush()
        except Exception:
            self.handleError(record)
def setup_logging(
    output_dir: Optional[str] = None,
    log_to_file: bool = False,
    log_level: int = logging.INFO,
    log_to_console: bool = True,
    use_tqdm_handler: bool = True,
    rank: int = 0,
    world_size: int = 1,
):
    """
    Configure the root logger with optional console and file handlers.

    Args:
        output_dir: Directory for log files
        log_to_file: Whether to log to file
        log_level: Logging level
        log_to_console: Whether to log to console
        use_tqdm_handler: Use tqdm-compatible handler for console logging
        rank: Process rank for distributed training
        world_size: World size for distributed training
    """
    handlers: list[logging.Handler] = []
    log_path = None
    # Console output is restricted to the main process in distributed runs.
    if log_to_console and (world_size == 1 or rank == 0):
        if use_tqdm_handler:
            console = TqdmLoggingHandler()
        else:
            console = logging.StreamHandler(sys.stdout)
        console.setLevel(log_level)
        handlers.append(console)
    if log_to_file:
        target_dir = output_dir if output_dir is not None else "./logs"
        os.makedirs(target_dir, exist_ok=True)
        # Each rank writes its own file so logs never interleave.
        name = f"train_rank{rank}.log" if world_size > 1 else "train.log"
        log_path = os.path.join(target_dir, name)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(log_level)
        handlers.append(file_handler)
    if world_size > 1:
        # Include the rank so interleaved records stay attributable.
        fmt = f'[%(asctime)s][Rank {rank}] %(levelname)s %(name)s: %(message)s'
    else:
        fmt = '[%(asctime)s] %(levelname)s %(name)s: %(message)s'
    formatter = logging.Formatter(fmt, datefmt='%Y-%m-%d %H:%M:%S')
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    # Replace any pre-existing handlers to avoid duplicate records.
    for old in list(root_logger.handlers):
        root_logger.removeHandler(old)
    for handler in handlers:
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)
    if rank == 0:
        logging.info(
            f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
            f"Log file: {log_path if log_to_file else 'none'} | "
            f"World size: {world_size}"
        )
def get_logger(name: str) -> logging.Logger:
    """Return the logger registered under *name*."""
    return logging.getLogger(name)
class MetricsLogger:
    """Persists training metrics and run artifacts as JSON files."""

    def __init__(self, output_dir: str):
        # All artifacts live under output_dir; metrics append as JSONL.
        self.output_dir = output_dir
        self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
        os.makedirs(output_dir, exist_ok=True)

    def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
        """Append one timestamped metrics record for *step* to the JSONL file."""
        record = {
            'step': step,
            'phase': phase,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics,
        }
        with open(self.metrics_file, 'a') as fh:
            fh.write(json.dumps(record) + '\n')

    def log_config(self, config: Dict[str, Any]):
        """Persist the run configuration as config.json."""
        with open(os.path.join(self.output_dir, 'config.json'), 'w') as fh:
            json.dump(config, fh, indent=2)

    def log_summary(self, summary: Dict[str, Any]):
        """Persist the final training summary as summary.json."""
        with open(os.path.join(self.output_dir, 'summary.json'), 'w') as fh:
            json.dump(summary, fh, indent=2)
class TensorBoardLogger:
    """Thin SummaryWriter wrapper that degrades to a no-op when disabled
    or when TensorBoard is not installed."""

    def __init__(self, log_dir: str, enabled: bool = True):
        self.enabled = enabled
        if not self.enabled:
            return
        try:
            from torch.utils.tensorboard.writer import SummaryWriter
            self.writer = SummaryWriter(log_dir)
        except ImportError:
            get_logger(__name__).warning(
                "TensorBoard not available. Install with: pip install tensorboard"
            )
            self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int):
        """Record a single scalar under *tag*."""
        if self.enabled:
            self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
        """Record several related scalars grouped under *main_tag*."""
        if self.enabled:
            self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Record a histogram of *values*."""
        if self.enabled:
            self.writer.add_histogram(tag, values, step)

    def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
        """Record embeddings for the projector view."""
        if self.enabled:
            self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)

    def flush(self):
        """Flush pending events to disk."""
        if self.enabled:
            self.writer.flush()

    def close(self):
        """Close the underlying writer."""
        if self.enabled:
            self.writer.close()
class ProgressLogger:
    """Periodically logs step progress, loss, throughput and ETA."""

    def __init__(self, total_steps: int, log_interval: int = 100):
        self.total_steps = total_steps
        self.log_interval = log_interval
        self.logger = get_logger(__name__)
        self.start_time = datetime.now()

    def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
        """Log loss/metrics/speed for *step* every ``log_interval`` steps."""
        if step % self.log_interval != 0:
            return
        elapsed = (datetime.now() - self.start_time).total_seconds()
        steps_per_sec = step / elapsed if elapsed > 0 else 0
        eta_seconds = (self.total_steps - step) / steps_per_sec if steps_per_sec > 0 else 0
        parts = [f"Step {step}/{self.total_steps}", f"Loss: {loss:.4f}"]
        if metrics:
            parts.append(" | ".join(f"{k}: {v:.4f}" for k, v in metrics.items()))
        parts.append(f"Speed: {steps_per_sec:.2f} steps/s")
        parts.append(f"ETA: {self._format_time(eta_seconds)}")
        self.logger.info(" | ".join(parts))

    def _format_time(self, seconds: float) -> str:
        """Render a duration as '2h 3m 4s', '3m 4s' or '4s'."""
        total = int(seconds)
        hours, remainder = divmod(total, 3600)
        minutes, secs = divmod(remainder, 60)
        if hours > 0:
            return f"{hours}h {minutes}m {secs}s"
        if minutes > 0:
            return f"{minutes}m {secs}s"
        return f"{secs}s"
def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
    """Log parameter counts (total/trainable/frozen) and, at DEBUG level,
    every parameter's shape.

    Args:
        model: Model to inspect.
        logger: Logger to use; defaults to this module's logger.
    """
    if logger is None:
        logger = get_logger(__name__)
    total = 0
    trainable = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    logger.info(f"Model: {model.__class__.__name__}")
    logger.info(f"Total parameters: {total:,}")
    logger.info(f"Trainable parameters: {trainable:,}")
    logger.info(f"Non-trainable parameters: {total - trainable:,}")
    logger.debug("Parameter shapes:")
    for name, param in model.named_parameters():
        logger.debug(f"  {name}: {list(param.shape)}")
def log_training_summary(
    config: Dict[str, Any],
    metrics: Dict[str, float],
    elapsed_time: float,
    output_dir: str
):
    """Persist and print a final training summary.

    Writes ``training_summary.json`` into *output_dir* and echoes the
    final metrics to the module logger.

    Args:
        config: Full training configuration.
        metrics: Final metric name -> value mapping.
        elapsed_time: Total training time in seconds.
        output_dir: Directory that receives the JSON summary.
    """
    logger = get_logger(__name__)
    summary = {
        'config': config,
        'final_metrics': metrics,
        'training_time': elapsed_time,
        'timestamp': datetime.now().isoformat(),
    }
    # Save to file
    summary_file = os.path.join(output_dir, 'training_summary.json')
    with open(summary_file, 'w') as fh:
        json.dump(summary, fh, indent=2)
    # Log to console
    banner = "=" * 50
    logger.info(banner)
    logger.info("TRAINING SUMMARY")
    logger.info(banner)
    logger.info(f"Training time: {elapsed_time:.2f} seconds")
    logger.info("Final metrics:")
    for name, value in metrics.items():
        logger.info(f"  {name}: {value:.4f}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info(banner)