init commit
344
utils/logging.py
Normal file
@@ -0,0 +1,344 @@
import logging
import os
import sys
from datetime import datetime
from typing import Optional, Dict, Any
import json
import torch
from tqdm import tqdm
import numpy as np


def create_progress_bar(
    iterable,
    desc: str = "",
    total: Optional[int] = None,
    disable: bool = False,
    position: int = 0,
    leave: bool = True,
    ncols: Optional[int] = None,
    unit: str = "it",
    dynamic_ncols: bool = True,
):
    """
    Create a tqdm progress bar with consistent settings across the project.

    Args:
        iterable: Iterable to wrap with a progress bar
        desc: Description for the progress bar
        total: Total number of iterations (if None, tqdm infers it from the
            iterable when possible)
        disable: Whether to disable the progress bar
        position: Position of the progress bar (for multiple bars)
        leave: Whether to leave the progress bar after completion
        ncols: Width of the progress bar
        unit: Unit of iteration (e.g., "batch", "step")
        dynamic_ncols: Automatically adjust width based on terminal size

    Returns:
        tqdm progress bar object
    """
    # In distributed training, only the main process should render a bar
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            # Disable progress bar for non-main processes
            disable = True
    except ImportError:
        pass

    # Create progress bar with consistent settings
    return tqdm(
        iterable,
        desc=desc,
        total=total,
        disable=disable,
        position=position,
        leave=leave,
        ncols=ncols,
        unit=unit,
        dynamic_ncols=dynamic_ncols,
        # Write to stdout; console log records go through tqdm.write
        # (see TqdmLoggingHandler) so they don't clobber the bar
        file=sys.stdout,
        # Smooth the displayed rate estimate
        smoothing=0.1,
        # Consistent bar style across the project
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]',
    )
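
# Illustrative usage (not part of the original commit); `train_loader` is a
# hypothetical DataLoader:
#
#     for batch in create_progress_bar(train_loader, desc="Epoch 1/10",
#                                      unit="batch", total=len(train_loader)):
#         ...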


class TqdmLoggingHandler(logging.Handler):
    """Custom logging handler that plays nicely with tqdm progress bars"""

    def emit(self, record):
        try:
            msg = self.format(record)
            # Use tqdm.write to avoid breaking progress bars
            tqdm.write(msg, file=sys.stdout)
            self.flush()
        except Exception:
            self.handleError(record)
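
# Illustrative (not in the original file): the handler can also be attached
# manually, outside of setup_logging():
#
#     handler = TqdmLoggingHandler()
#     handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
#     logging.getLogger().addHandler(handler)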


def setup_logging(
    output_dir: Optional[str] = None,
    log_to_file: bool = False,
    log_level: int = logging.INFO,
    log_to_console: bool = True,
    use_tqdm_handler: bool = True,
    rank: int = 0,
    world_size: int = 1,
):
    """
    Set up logging with optional file and console handlers.

    Args:
        output_dir: Directory for log files
        log_to_file: Whether to log to a file
        log_level: Logging level
        log_to_console: Whether to log to the console
        use_tqdm_handler: Use the tqdm-compatible handler for console logging
        rank: Process rank for distributed training
        world_size: World size for distributed training
    """
    handlers: list[logging.Handler] = []
    log_path = None

    # Only log to the console from the main process in distributed training
    if log_to_console and (world_size == 1 or rank == 0):
        if use_tqdm_handler:
            # tqdm-compatible handler keeps log lines and progress bars apart
            console_handler = TqdmLoggingHandler()
        else:
            console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(log_level)
        handlers.append(console_handler)

    if log_to_file:
        if output_dir is None:
            output_dir = "./logs"
        os.makedirs(output_dir, exist_ok=True)

        # Add the rank to the filename for distributed training
        if world_size > 1:
            log_filename = f"train_rank{rank}.log"
        else:
            log_filename = "train.log"

        log_path = os.path.join(output_dir, log_filename)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(log_level)
        handlers.append(file_handler)

    # Create the formatter
    if world_size > 1:
        # Include the rank in the log format for distributed training
        formatter = logging.Formatter(
            f'[%(asctime)s][Rank {rank}] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    else:
        formatter = logging.Formatter(
            '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    # Apply the formatter to all handlers
    for handler in handlers:
        handler.setFormatter(formatter)

    # Configure the root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)

    # Remove existing handlers to avoid duplicate output
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)

    # Add the new handlers
    for handler in handlers:
        root_logger.addHandler(handler)

    # Announce initialization from the main process
    if rank == 0:
        logging.info(
            f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
            f"Log file: {log_path if log_to_file else 'none'} | "
            f"World size: {world_size}"
        )
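
# Illustrative usage (not part of the original commit), e.g. at the top of a
# training script; the paths shown are made up:
#
#     setup_logging(output_dir="runs/exp1", log_to_file=True,
#                   rank=rank, world_size=world_size)
#     logger = get_logger(__name__)
#     logger.info("training started")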


def get_logger(name: str) -> logging.Logger:
    """Get a logger with the specified name"""
    return logging.getLogger(name)


class MetricsLogger:
    """Logger specifically for training metrics"""

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
        os.makedirs(output_dir, exist_ok=True)

    def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
        """Append one metrics entry to the JSONL metrics file"""
        entry = {
            'step': step,
            'phase': phase,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics
        }

        with open(self.metrics_file, 'a') as f:
            # default=str guards against values json cannot serialize
            # natively (e.g. numpy scalars)
            f.write(json.dumps(entry, default=str) + '\n')

    def log_config(self, config: Dict[str, Any]):
        """Write the run configuration to config.json"""
        config_file = os.path.join(self.output_dir, 'config.json')
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)

    def log_summary(self, summary: Dict[str, Any]):
        """Write the training summary to summary.json"""
        summary_file = os.path.join(self.output_dir, 'summary.json')
        with open(summary_file, 'w') as f:
            json.dump(summary, f, indent=2)
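
# Illustrative usage (not part of the original commit); the path and values
# are made up:
#
#     mlogger = MetricsLogger("runs/exp1")
#     mlogger.log_config({"lr": 3e-4, "batch_size": 32})
#     mlogger.log(step=100, metrics={"loss": 0.42}, phase="train")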


class TensorBoardLogger:
    """Thin wrapper around the TensorBoard SummaryWriter"""

    def __init__(self, log_dir: str, enabled: bool = True):
        self.enabled = enabled
        # Keep the attribute defined even when disabled or unavailable
        self.writer = None
        if self.enabled:
            try:
                from torch.utils.tensorboard.writer import SummaryWriter
                self.writer = SummaryWriter(log_dir)
            except ImportError:
                get_logger(__name__).warning(
                    "TensorBoard not available. Install with: pip install tensorboard"
                )
                self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int):
        """Log a scalar value"""
        if self.enabled:
            self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
        """Log multiple related scalars under one main tag"""
        if self.enabled:
            self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Log a histogram of values"""
        if self.enabled:
            self.writer.add_histogram(tag, values, step)

    def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
        """Log embeddings for the projector visualization"""
        if self.enabled:
            self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)

    def flush(self):
        """Flush pending events to disk"""
        if self.enabled:
            self.writer.flush()

    def close(self):
        """Close the underlying writer"""
        if self.enabled:
            self.writer.close()
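
# Illustrative usage (not part of the original commit); `runs/exp1/tb` is a
# made-up path:
#
#     tb = TensorBoardLogger("runs/exp1/tb")
#     tb.log_scalar("train/loss", 0.42, step=100)
#     tb.close()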


class ProgressLogger:
    """Helper for logging training progress at a fixed step interval"""

    def __init__(self, total_steps: int, log_interval: int = 100):
        self.total_steps = total_steps
        self.log_interval = log_interval
        self.logger = get_logger(__name__)
        self.start_time = datetime.now()

    def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
        """Log progress for a training step"""
        if step % self.log_interval == 0:
            elapsed = (datetime.now() - self.start_time).total_seconds()
            steps_per_sec = step / elapsed if elapsed > 0 else 0
            eta_seconds = (self.total_steps - step) / steps_per_sec if steps_per_sec > 0 else 0

            msg = f"Step {step}/{self.total_steps} | Loss: {loss:.4f}"

            if metrics:
                metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
                msg += f" | {metric_str}"

            msg += f" | Speed: {steps_per_sec:.2f} steps/s | ETA: {self._format_time(eta_seconds)}"

            self.logger.info(msg)

    def _format_time(self, seconds: float) -> str:
        """Format seconds as a human-readable duration"""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = int(seconds % 60)

        if hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        elif minutes > 0:
            return f"{minutes}m {seconds}s"
        else:
            return f"{seconds}s"


def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
    """Log a model's class name and parameter counts"""
    if logger is None:
        logger = get_logger(__name__)

    # Count total and trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"Model: {model.__class__.__name__}")
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Non-trainable parameters: {total_params - trainable_params:,}")

    # Log parameter shapes at DEBUG level
    logger.debug("Parameter shapes:")
    for name, param in model.named_parameters():
        logger.debug(f"  {name}: {list(param.shape)}")


def log_training_summary(
    config: Dict[str, Any],
    metrics: Dict[str, float],
    elapsed_time: float,
    output_dir: str
):
    """Log a comprehensive training summary to file and console"""
    logger = get_logger(__name__)

    summary = {
        'config': config,
        'final_metrics': metrics,
        'training_time': elapsed_time,
        'timestamp': datetime.now().isoformat()
    }

    # Save the summary to file, creating the directory if needed
    os.makedirs(output_dir, exist_ok=True)
    summary_file = os.path.join(output_dir, 'training_summary.json')
    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)

    # Log the summary to the console
    logger.info("=" * 50)
    logger.info("TRAINING SUMMARY")
    logger.info("=" * 50)
    logger.info(f"Training time: {elapsed_time:.2f} seconds")
    logger.info("Final metrics:")
    for metric, value in metrics.items():
        logger.info(f"  {metric}: {value:.4f}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info("=" * 50)