import logging
import os
import sys
from datetime import datetime
from typing import Optional, Dict, Any
import json

import torch
from tqdm import tqdm
import numpy as np

def create_progress_bar(
    iterable,
    desc: str = "",
    total: Optional[int] = None,
    disable: bool = False,
    position: int = 0,
    leave: bool = True,
    ncols: Optional[int] = None,
    unit: str = "it",
    dynamic_ncols: bool = True,
):
    """
    Create a tqdm progress bar with consistent settings across the project.

    Args:
        iterable: Iterable to wrap with progress bar
        desc: Description for the progress bar
        total: Total number of iterations (if None, will try to get from iterable)
        disable: Whether to disable the progress bar
        position: Position of the progress bar (for multiple bars)
        leave: Whether to leave the progress bar after completion
        ncols: Width of the progress bar
        unit: Unit of iteration (e.g., "batch", "step")
        dynamic_ncols: Automatically adjust width based on terminal size

    Returns:
        tqdm progress bar object
    """
    # Check if we're in distributed training
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            # Disable progress bar for non-main processes
            disable = True
    except (ImportError, RuntimeError):
        pass

    # Create progress bar with consistent settings
    return tqdm(
        iterable,
        desc=desc,
        total=total,
        disable=disable,
        position=position,
        leave=leave,
        ncols=ncols,
        unit=unit,
        dynamic_ncols=dynamic_ncols,
        # Write to stdout so bars and TqdmLoggingHandler output share a stream
        file=sys.stdout,
        # Ensure smooth updates
        smoothing=0.1,
        # Format with consistent style
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
    )

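# Usage sketch (illustrative only; `train_loader` is a hypothetical DataLoader):
#
#     for batch in create_progress_bar(train_loader, desc="Epoch 1", unit="batch"):
#         ...  # forward/backward pass
#
# Non-rank-0 processes automatically get a disabled bar under torch.distributed.
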
class TqdmLoggingHandler(logging.Handler):
    """Custom logging handler that plays nicely with tqdm progress bars"""

    def emit(self, record):
        try:
            msg = self.format(record)
            # Use tqdm.write to avoid breaking progress bars
            tqdm.write(msg, file=sys.stdout)
            self.flush()
        except Exception:
            self.handleError(record)

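# Usage sketch (illustrative; setup_logging below wires this up automatically
# when use_tqdm_handler=True):
#
#     handler = TqdmLoggingHandler()
#     handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
#     logging.getLogger().addHandler(handler)
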
def setup_logging(
    output_dir: Optional[str] = None,
    log_to_file: bool = False,
    log_level: int = logging.INFO,
    log_to_console: bool = True,
    use_tqdm_handler: bool = True,
    rank: int = 0,
    world_size: int = 1,
):
    """
    Setup logging with optional file and console handlers.

    Args:
        output_dir: Directory for log files
        log_to_file: Whether to log to file
        log_level: Logging level
        log_to_console: Whether to log to console
        use_tqdm_handler: Use tqdm-compatible handler for console logging
        rank: Process rank for distributed training
        world_size: World size for distributed training
    """
    handlers: list[logging.Handler] = []
    log_path = None

    # Only log to console from main process in distributed training
    if log_to_console and (world_size == 1 or rank == 0):
        if use_tqdm_handler:
            # Use tqdm-compatible handler
            console_handler = TqdmLoggingHandler()
        else:
            console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(log_level)
        handlers.append(console_handler)

    if log_to_file:
        if output_dir is None:
            output_dir = "./logs"
        os.makedirs(output_dir, exist_ok=True)

        # Add rank to filename for distributed training
        if world_size > 1:
            log_filename = f"train_rank{rank}.log"
        else:
            log_filename = "train.log"

        log_path = os.path.join(output_dir, log_filename)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(log_level)
        handlers.append(file_handler)

    # Create formatter
    if world_size > 1:
        # Include rank in log format for distributed training
        formatter = logging.Formatter(
            f'[%(asctime)s][Rank {rank}] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    else:
        formatter = logging.Formatter(
            '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    # Apply formatter to all handlers
    for handler in handlers:
        handler.setFormatter(formatter)

    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)

    # Remove existing handlers to avoid duplicates
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)

    # Add new handlers
    for handler in handlers:
        root_logger.addHandler(handler)

    # Log initialization message
    if rank == 0:
        logging.info(
            f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
            f"Log file: {log_path if log_to_file else 'none'} | "
            f"World size: {world_size}"
        )

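# Usage sketch (illustrative; the output directory and the caller's rank /
# world_size variables are hypothetical):
#
#     setup_logging(output_dir="runs/exp1", log_to_file=True,
#                   rank=rank, world_size=world_size)
#     logger = get_logger(__name__)
#     logger.info("training started")
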
def get_logger(name: str) -> logging.Logger:
    """Get a logger with the specified name"""
    return logging.getLogger(name)

class MetricsLogger:
    """Logger specifically for training metrics"""

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
        os.makedirs(output_dir, exist_ok=True)

    def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
        """Log metrics to file"""
        entry = {
            'step': step,
            'phase': phase,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics
        }

        with open(self.metrics_file, 'a') as f:
            f.write(json.dumps(entry) + '\n')

    def log_config(self, config: Dict[str, Any]):
        """Log configuration"""
        config_file = os.path.join(self.output_dir, 'config.json')
        with open(config_file, 'w') as f:
            json.dump(config, f, indent=2)

    def log_summary(self, summary: Dict[str, Any]):
        """Log training summary"""
        summary_file = os.path.join(self.output_dir, 'summary.json')
        with open(summary_file, 'w') as f:
            json.dump(summary, f, indent=2)

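# Usage sketch (illustrative; "runs/exp1" is a hypothetical output directory):
#
#     metrics_logger = MetricsLogger("runs/exp1")
#     metrics_logger.log_config({"lr": 3e-4, "batch_size": 32})
#     metrics_logger.log(step=100, metrics={"loss": 0.52}, phase="train")
#
# Each call to log() appends one JSON line to runs/exp1/metrics.jsonl.
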
class TensorBoardLogger:
    """TensorBoard logging wrapper"""

    def __init__(self, log_dir: str, enabled: bool = True):
        self.enabled = enabled
        if self.enabled:
            try:
                from torch.utils.tensorboard.writer import SummaryWriter
                self.writer = SummaryWriter(log_dir)
            except ImportError:
                get_logger(__name__).warning(
                    "TensorBoard not available. Install with: pip install tensorboard"
                )
                self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int):
        """Log a scalar value"""
        if self.enabled:
            self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
        """Log multiple scalars"""
        if self.enabled:
            self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Log a histogram"""
        if self.enabled:
            self.writer.add_histogram(tag, values, step)

    def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
        """Log embeddings for visualization"""
        if self.enabled:
            self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)

    def flush(self):
        """Flush the writer"""
        if self.enabled:
            self.writer.flush()

    def close(self):
        """Close the writer"""
        if self.enabled:
            self.writer.close()

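# Usage sketch (illustrative; the log directory and values are hypothetical):
#
#     tb = TensorBoardLogger("runs/exp1/tensorboard")
#     tb.log_scalar("train/loss", 0.52, step=100)
#     tb.log_scalars("lr", {"group0": 3e-4, "group1": 1e-4}, step=100)
#     tb.close()
#
# If tensorboard is not installed, a warning is logged once at construction and
# all later calls are silent no-ops.
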
class ProgressLogger:
    """Helper for logging training progress"""

    def __init__(self, total_steps: int, log_interval: int = 100):
        self.total_steps = total_steps
        self.log_interval = log_interval
        self.logger = get_logger(__name__)
        self.start_time = datetime.now()

    def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
        """Log progress for a training step"""
        if step % self.log_interval == 0:
            elapsed = (datetime.now() - self.start_time).total_seconds()
            steps_per_sec = step / elapsed if elapsed > 0 else 0
            eta_seconds = (self.total_steps - step) / steps_per_sec if steps_per_sec > 0 else 0

            msg = f"Step {step}/{self.total_steps} | Loss: {loss:.4f}"

            if metrics:
                metric_str = " | ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
                msg += f" | {metric_str}"

            msg += f" | Speed: {steps_per_sec:.2f} steps/s | ETA: {self._format_time(eta_seconds)}"

            self.logger.info(msg)

    def _format_time(self, seconds: float) -> str:
        """Format seconds to human readable time"""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds = int(seconds % 60)

        if hours > 0:
            return f"{hours}h {minutes}m {seconds}s"
        elif minutes > 0:
            return f"{minutes}m {seconds}s"
        else:
            return f"{seconds}s"

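# Usage sketch (illustrative; train_one_step() and the step counts are made up):
#
#     progress = ProgressLogger(total_steps=10_000, log_interval=100)
#     for step in range(1, 10_001):
#         loss = train_one_step()  # hypothetical training step
#         progress.log_step(step, loss, metrics={"lr": 3e-4})
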
def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
    """Log model information"""
    if logger is None:
        logger = get_logger(__name__)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.info(f"Model: {model.__class__.__name__}")
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Non-trainable parameters: {total_params - trainable_params:,}")

    # Log parameter shapes
    logger.debug("Parameter shapes:")
    for name, param in model.named_parameters():
        logger.debug(f"  {name}: {list(param.shape)}")

def log_training_summary(
    config: Dict[str, Any],
    metrics: Dict[str, float],
    elapsed_time: float,
    output_dir: str
):
    """Log comprehensive training summary"""
    logger = get_logger(__name__)

    summary = {
        'config': config,
        'final_metrics': metrics,
        'training_time': elapsed_time,
        'timestamp': datetime.now().isoformat()
    }

    # Save to file (create the output directory if it does not exist yet)
    os.makedirs(output_dir, exist_ok=True)
    summary_file = os.path.join(output_dir, 'training_summary.json')
    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)

    # Log to console
    logger.info("=" * 50)
    logger.info("TRAINING SUMMARY")
    logger.info("=" * 50)
    logger.info(f"Training time: {elapsed_time:.2f} seconds")
    logger.info("Final metrics:")
    for metric, value in metrics.items():
        logger.info(f"  {metric}: {value:.4f}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info("=" * 50)
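
if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the training
    # pipeline): exercises console logging, model info reporting, and metrics
    # logging with a tiny stand-in model.
    setup_logging(log_to_console=True)
    logger = get_logger(__name__)

    demo_model = torch.nn.Linear(8, 2)
    log_model_info(demo_model, logger)

    metrics_logger = MetricsLogger(output_dir="./logs/demo")
    metrics_logger.log(step=1, metrics={"loss": 0.0}, phase="debug")
    logger.info("Demo metrics written to %s", metrics_logger.metrics_file)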