init commit

This commit is contained in:
ldy
2025-07-22 16:55:25 +08:00
commit 36003b83e2
67 changed files with 76613 additions and 0 deletions

344
utils/logging.py Normal file
View File

@@ -0,0 +1,344 @@
import logging
import os
import sys
from datetime import datetime
from typing import Optional, Dict, Any
import json
import torch
from tqdm import tqdm
import numpy as np
def create_progress_bar(
    iterable,
    desc: str = "",
    total: Optional[int] = None,
    disable: bool = False,
    position: int = 0,
    leave: bool = True,
    ncols: Optional[int] = None,
    unit: str = "it",
    dynamic_ncols: bool = True,
):
    """
    Create a tqdm progress bar with consistent settings across the project.

    Args:
        iterable: Iterable to wrap with progress bar
        desc: Description for the progress bar
        total: Total number of iterations (if None, will try to get from iterable)
        disable: Whether to disable the progress bar
        position: Position of the progress bar (for multiple bars)
        leave: Whether to leave the progress bar after completion
        ncols: Width of the progress bar
        unit: Unit of iteration (e.g., "batch", "step")
        dynamic_ncols: Automatically adjust width based on terminal size

    Returns:
        tqdm progress bar object
    """
    # In distributed training only rank 0 shows a bar; all other ranks are
    # silenced so the console isn't flooded with one bar per process.
    try:
        import torch.distributed as dist

        # is_available() guards builds of torch compiled without distributed
        # support, where is_initialized() itself would raise.
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            disable = True
    except Exception:
        # Best-effort: if distributed state can't be queried (torch missing
        # or misconfigured), fall back to showing the bar.
        pass
    # Create progress bar with consistent settings
    return tqdm(
        iterable,
        desc=desc,
        total=total,
        disable=disable,
        position=position,
        leave=leave,
        ncols=ncols,
        unit=unit,
        dynamic_ncols=dynamic_ncols,
        # Use custom write method to avoid conflicts with logging
        file=sys.stdout,
        # Ensure smooth updates
        smoothing=0.1,
        # Format with consistent style
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
    )
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that routes records through ``tqdm.write`` so emitted
    log lines do not corrupt any active tqdm progress bars."""

    def emit(self, record):
        """Format ``record`` and print it without breaking progress bars."""
        try:
            formatted = self.format(record)
            # tqdm.write temporarily clears active bars, prints, then redraws.
            tqdm.write(formatted, file=sys.stdout)
            self.flush()
        except Exception:
            self.handleError(record)
def setup_logging(
    output_dir: Optional[str] = None,
    log_to_file: bool = False,
    log_level: int = logging.INFO,
    log_to_console: bool = True,
    use_tqdm_handler: bool = True,
    rank: int = 0,
    world_size: int = 1,
):
    """
    Configure the root logger with optional console and file handlers.

    Args:
        output_dir: Directory for log files
        log_to_file: Whether to log to file
        log_level: Logging level
        log_to_console: Whether to log to console
        use_tqdm_handler: Use tqdm-compatible handler for console logging
        rank: Process rank for distributed training
        world_size: World size for distributed training
    """
    active_handlers: list[logging.Handler] = []
    log_path = None

    # Console output is restricted to the main process under distributed runs.
    if log_to_console and (world_size == 1 or rank == 0):
        console = (
            TqdmLoggingHandler()
            if use_tqdm_handler
            else logging.StreamHandler(sys.stdout)
        )
        console.setLevel(log_level)
        active_handlers.append(console)

    if log_to_file:
        target_dir = output_dir if output_dir is not None else "./logs"
        os.makedirs(target_dir, exist_ok=True)
        # Per-rank file names keep distributed processes from clobbering
        # each other's logs.
        filename = f"train_rank{rank}.log" if world_size > 1 else "train.log"
        log_path = os.path.join(target_dir, filename)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(log_level)
        active_handlers.append(file_handler)

    # Distributed runs tag every record with the emitting rank.
    prefix = f'[%(asctime)s][Rank {rank}]' if world_size > 1 else '[%(asctime)s]'
    formatter = logging.Formatter(
        prefix + ' %(levelname)s %(name)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    for handler in active_handlers:
        handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    # Clear any previously-installed handlers so repeated setup calls
    # don't produce duplicated log lines.
    for stale in root_logger.handlers[:]:
        root_logger.removeHandler(stale)
    for handler in active_handlers:
        root_logger.addHandler(handler)

    # Announce the configuration once, from the main process only.
    if rank == 0:
        logging.info(
            f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
            f"Log file: {log_path if log_to_file else 'none'} | "
            f"World size: {world_size}"
        )
def get_logger(name: str) -> logging.Logger:
    """Return the ``logging.Logger`` registered under ``name``."""
    logger = logging.getLogger(name)
    return logger
class MetricsLogger:
    """Writes training metrics and run artifacts as JSON files under one
    output directory."""

    def __init__(self, output_dir: str):
        # A single JSONL file accumulates every metric entry for the run.
        self.output_dir = output_dir
        self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
        os.makedirs(output_dir, exist_ok=True)

    def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
        """Append one timestamped metrics record as a JSON line."""
        record = {
            'step': step,
            'phase': phase,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics,
        }
        with open(self.metrics_file, 'a') as fh:
            fh.write(json.dumps(record))
            fh.write('\n')

    def _dump(self, filename: str, payload: Dict[str, Any]):
        """Write ``payload`` as pretty-printed JSON inside the output dir."""
        with open(os.path.join(self.output_dir, filename), 'w') as fh:
            json.dump(payload, fh, indent=2)

    def log_config(self, config: Dict[str, Any]):
        """Persist the run configuration to ``config.json``."""
        self._dump('config.json', config)

    def log_summary(self, summary: Dict[str, Any]):
        """Persist the training summary to ``summary.json``."""
        self._dump('summary.json', summary)
class TensorBoardLogger:
    """Thin wrapper around torch's ``SummaryWriter`` that degrades gracefully.

    When tensorboard is not installed (or ``enabled`` is False) the logger
    switches itself off and every method becomes a no-op.
    """

    def __init__(self, log_dir: str, enabled: bool = True):
        self.enabled = enabled
        if not self.enabled:
            return
        try:
            from torch.utils.tensorboard.writer import SummaryWriter
            self.writer = SummaryWriter(log_dir)
        except ImportError:
            get_logger(__name__).warning(
                "TensorBoard not available. Install with: pip install tensorboard"
            )
            self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int):
        """Record a single scalar value."""
        if not self.enabled:
            return
        self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
        """Record several scalars under one main tag."""
        if not self.enabled:
            return
        self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Record a histogram of ``values``."""
        if not self.enabled:
            return
        self.writer.add_histogram(tag, values, step)

    def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
        """Record embeddings for the projector visualization."""
        if not self.enabled:
            return
        self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)

    def flush(self):
        """Flush pending events to disk."""
        if not self.enabled:
            return
        self.writer.flush()

    def close(self):
        """Close the underlying writer."""
        if not self.enabled:
            return
        self.writer.close()
class ProgressLogger:
    """Periodic console logging of training loss, throughput, and ETA."""

    def __init__(self, total_steps: int, log_interval: int = 100):
        self.total_steps = total_steps
        self.log_interval = log_interval
        self.logger = get_logger(__name__)
        # Wall-clock start, used to derive steps/sec and remaining time.
        self.start_time = datetime.now()

    def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
        """Emit one progress line every ``log_interval`` steps."""
        if step % self.log_interval != 0:
            return
        elapsed = (datetime.now() - self.start_time).total_seconds()
        rate = step / elapsed if elapsed > 0 else 0
        eta_seconds = (self.total_steps - step) / rate if rate > 0 else 0
        parts = [f"Step {step}/{self.total_steps}", f"Loss: {loss:.4f}"]
        if metrics:
            parts.extend(f"{k}: {v:.4f}" for k, v in metrics.items())
        parts.append(f"Speed: {rate:.2f} steps/s")
        parts.append(f"ETA: {self._format_time(eta_seconds)}")
        self.logger.info(" | ".join(parts))

    def _format_time(self, seconds: float) -> str:
        """Render a duration like ``'2h 5m 3s'``, dropping empty leading units."""
        hours, remainder = divmod(int(seconds), 3600)
        minutes, secs = divmod(remainder, 60)
        if hours > 0:
            return f"{hours}h {minutes}m {secs}s"
        if minutes > 0:
            return f"{minutes}m {secs}s"
        return f"{secs}s"
def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
    """Log a model's parameter counts at INFO and its parameter shapes at DEBUG.

    Args:
        model: The module whose parameters are summarized.
        logger: Logger to write to; defaults to this module's logger.
    """
    if logger is None:
        logger = get_logger(__name__)
    # Count total and trainable parameters in a single pass.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    logger.info(f"Model: {model.__class__.__name__}")
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Non-trainable parameters: {total_params - trainable_params:,}")
    # Per-parameter shapes only at DEBUG, since large models produce many lines.
    logger.debug("Parameter shapes:")
    for name, param in model.named_parameters():
        logger.debug(f"  {name}: {list(param.shape)}")
def log_training_summary(
    config: Dict[str, Any],
    metrics: Dict[str, float],
    elapsed_time: float,
    output_dir: str
):
    """Persist a final training summary to JSON and echo it to the log.

    Args:
        config: The run configuration to record.
        metrics: Final metric name -> value pairs.
        elapsed_time: Total training time in seconds.
        output_dir: Directory where ``training_summary.json`` is written.
    """
    logger = get_logger(__name__)
    summary = {
        'config': config,
        'final_metrics': metrics,
        'training_time': elapsed_time,
        'timestamp': datetime.now().isoformat(),
    }
    # Persist the summary alongside other run artifacts.
    summary_file = os.path.join(output_dir, 'training_summary.json')
    with open(summary_file, 'w') as fh:
        json.dump(summary, fh, indent=2)
    # Echo a human-readable version to the console/log.
    divider = "=" * 50
    logger.info(divider)
    logger.info("TRAINING SUMMARY")
    logger.info(divider)
    logger.info(f"Training time: {elapsed_time:.2f} seconds")
    logger.info("Final metrics:")
    for name, score in metrics.items():
        logger.info(f"  {name}: {score:.4f}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info(divider)