import json
import logging
import os
import sys
from datetime import datetime
from typing import Any, Dict, Optional

import numpy as np
import torch
from tqdm import tqdm


def create_progress_bar(
    iterable,
    desc: str = "",
    total: Optional[int] = None,
    disable: bool = False,
    position: int = 0,
    leave: bool = True,
    ncols: Optional[int] = None,
    unit: str = "it",
    dynamic_ncols: bool = True,
):
    """
    Create a tqdm progress bar with consistent settings across the project.

    Args:
        iterable: Iterable to wrap with a progress bar
        desc: Description for the progress bar
        total: Total number of iterations (if None, inferred from the iterable when possible)
        disable: Whether to disable the progress bar
        position: Position of the progress bar (for multiple bars)
        leave: Whether to leave the progress bar after completion
        ncols: Width of the progress bar
        unit: Unit of iteration (e.g., "batch", "step")
        dynamic_ncols: Automatically adjust width based on terminal size

    Returns:
        tqdm progress bar object
    """
    # Disable the bar on non-main processes during distributed training
    try:
        import torch.distributed as dist
        if dist.is_initialized() and dist.get_rank() != 0:
            disable = True
    except Exception:
        pass

    # Create progress bar with consistent settings
    return tqdm(
        iterable,
        desc=desc,
        total=total,
        disable=disable,
        position=position,
        leave=leave,
        ncols=ncols,
        unit=unit,
        dynamic_ncols=dynamic_ncols,
        # Write to stdout so bar updates interleave cleanly with tqdm.write-based logging
        file=sys.stdout,
        # Ensure smooth rate estimates
        smoothing=0.1,
        # Format with consistent style
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]',
    )


class TqdmLoggingHandler(logging.Handler):
    """Custom logging handler that plays nicely with tqdm progress bars."""

    def emit(self, record):
        try:
            msg = self.format(record)
            # Use tqdm.write to avoid breaking active progress bars
            tqdm.write(msg, file=sys.stdout)
            self.flush()
        except Exception:
            self.handleError(record)


def setup_logging(
    output_dir: Optional[str] = None,
    log_to_file: bool = False,
    log_level: int = logging.INFO,
    log_to_console: bool = True,
    use_tqdm_handler: bool = True,
    rank: int = 0,
    world_size: int = 1,
):
    """
    Set up logging with optional file and console handlers.

    Args:
        output_dir: Directory for log files
        log_to_file: Whether to log to a file
        log_level: Logging level
        log_to_console: Whether to log to the console
        use_tqdm_handler: Use the tqdm-compatible handler for console logging
        rank: Process rank for distributed training
        world_size: World size for distributed training
    """
    handlers: list[logging.Handler] = []
    log_path = None

    # Only log to console from the main process in distributed training
    if log_to_console and (world_size == 1 or rank == 0):
        if use_tqdm_handler:
            # Use the tqdm-compatible handler
            console_handler = TqdmLoggingHandler()
        else:
            console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(log_level)
        handlers.append(console_handler)

    if log_to_file:
        if output_dir is None:
            output_dir = "./logs"
        os.makedirs(output_dir, exist_ok=True)

        # Add the rank to the filename for distributed training
        if world_size > 1:
            log_filename = f"train_rank{rank}.log"
        else:
            log_filename = "train.log"

        log_path = os.path.join(output_dir, log_filename)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(log_level)
        handlers.append(file_handler)

    # Create the formatter (include the rank in distributed runs)
    if world_size > 1:
        formatter = logging.Formatter(
            f'[%(asctime)s][Rank {rank}] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    else:
        formatter = logging.Formatter(
            '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    # Apply the formatter to all handlers
    for handler in handlers:
        handler.setFormatter(formatter)

    # Configure the root logger, removing existing handlers to avoid duplicates
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    for handler in handlers:
        root_logger.addHandler(handler)

    # Log an initialization message from the main process
    if rank == 0:
        logging.info(
            f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
            f"Log file: {log_path if log_to_file else 'none'} | "
            f"World size: {world_size}"
        )


def get_logger(name: str) -> logging.Logger:
    """Get a logger with the specified name."""
    return logging.getLogger(name)
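
# Usage sketch (illustrative only, not part of the original API): a minimal
# training-script pattern combining setup_logging, get_logger and
# create_progress_bar so that log lines do not break the progress bar.
# The `dataloader` argument is a hypothetical iterable of batches.
def _example_logging_with_progress(dataloader, log_every: int = 50) -> None:
    setup_logging(log_to_console=True, use_tqdm_handler=True)
    logger = get_logger(__name__)
    for step, _batch in enumerate(create_progress_bar(dataloader, desc="train", unit="batch")):
        if step % log_every == 0:
            # Routed through TqdmLoggingHandler -> tqdm.write, so the bar stays intact
            logger.info("Processed step %d", step)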
class MetricsLogger:
    """Logger specifically for training metrics"""

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
        os.makedirs(output_dir, exist_ok=True)

    def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
        """Log metrics to file"""
        entry = {
            'step': step,
            'phase': phase,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics
        }
        with open(self.metrics_file, 'a') as f:
            f.write(json.dumps(entry) + '\n')

    def log_config(self, config: Dict[str, Any]):
        """Log configuration"""
        config_file = os.path.join(self.output_dir, 'config.json')
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)

    def log_summary(self, summary: Dict[str, Any]):
        """Log training summary"""
        summary_file = os.path.join(self.output_dir, 'summary.json')
        with open(summary_file, 'w') as f:
            json.dump(summary, f, indent=2)


class TensorBoardLogger:
    """TensorBoard logging wrapper"""

    def __init__(self, log_dir: str, enabled: bool = True):
        self.enabled = enabled
        if self.enabled:
            try:
                from torch.utils.tensorboard.writer import SummaryWriter
                self.writer = SummaryWriter(log_dir)
            except ImportError:
                get_logger(__name__).warning(
                    "TensorBoard not available. Install with: pip install tensorboard"
                )
                self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int):
        """Log a scalar value"""
        if self.enabled:
            self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
        """Log multiple scalars"""
        if self.enabled:
            self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Log a histogram"""
        if self.enabled:
            self.writer.add_histogram(tag, values, step)

    def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
        """Log embeddings for visualization"""
        if self.enabled:
            self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)

    def flush(self):
        """Flush the writer"""
        if self.enabled:
            self.writer.flush()

    def close(self):
        """Close the writer"""
        if self.enabled:
            self.writer.close()
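
# Usage sketch (illustrative only, not part of the original API): MetricsLogger
# and TensorBoardLogger are typically driven from the same loop so each metric
# lands both in metrics.jsonl and in TensorBoard. The loss values and the
# output directory below are placeholders.
def _example_metric_logging(output_dir: str = "./runs/example") -> None:
    metrics_logger = MetricsLogger(output_dir)
    tb_logger = TensorBoardLogger(os.path.join(output_dir, "tb"))
    for step in range(3):
        fake_loss = 1.0 / (step + 1)  # placeholder value
        metrics_logger.log(step, {"loss": fake_loss}, phase="train")
        tb_logger.log_scalar("train/loss", fake_loss, step)
    tb_logger.flush()
    tb_logger.close()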
class ProgressLogger:
    """Helper for logging training progress"""

    def __init__(self, total_steps: int, log_interval: int = 100):
        self.total_steps = total_steps
        self.log_interval = log_interval
        self.logger = get_logger(__name__)
        self.start_time = datetime.now()

    def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
        """Log progress for a training step"""
        if step % self.log_interval == 0:
            elapsed = (datetime.now() - self.start_time).total_seconds()
            steps_per_sec = step / elapsed if elapsed > 0 else 0
            eta_seconds = (self.total_steps - step) / steps_per_sec if steps_per_sec > 0 else 0

            msg = f"Step {step}/{self.total_steps} | Loss: {loss:.4f}"
            if metrics:
                metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
                msg += f" | {metric_str}"
            msg += f" | Speed: {steps_per_sec:.2f} steps/s | ETA: {self._format_time(eta_seconds)}"

            self.logger.info(msg)

    def _format_time(self, seconds: float) -> str:
        """Format seconds to human-readable time"""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = int(seconds % 60)

        if hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        elif minutes > 0:
            return f"{minutes}m {seconds}s"
        else:
            return f"{seconds}s"


def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
    """Log model information"""
    if logger is None:
        logger = get_logger(__name__)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"Model: {model.__class__.__name__}")
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Non-trainable parameters: {total_params - trainable_params:,}")

    # Log parameter shapes
    logger.debug("Parameter shapes:")
    for name, param in model.named_parameters():
        logger.debug(f"  {name}: {list(param.shape)}")
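
# Usage sketch (illustrative only, not part of the original API): log_model_info
# and ProgressLogger on a tiny placeholder model; torch.nn.Linear stands in for
# a real network and the loss values are synthetic.
def _example_progress_logging(total_steps: int = 300) -> None:
    model = torch.nn.Linear(16, 4)  # placeholder model
    log_model_info(model)
    progress = ProgressLogger(total_steps, log_interval=100)
    for step in range(1, total_steps + 1):
        fake_loss = 1.0 / step  # synthetic loss
        progress.log_step(step, fake_loss, metrics={"lr": 1e-3})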
{value:.4f}") logger.info(f"Summary saved to: {summary_file}") logger.info("=" * 50)