# bge_finetune/utils/logging.py
# Last modified: 2025-07-22 16:55:25 +08:00
import logging
import os
import sys
from datetime import datetime
from typing import Optional, Dict, Any
import json
import torch
from tqdm import tqdm
import numpy as np
def create_progress_bar(
    iterable,
    desc: str = "",
    total: Optional[int] = None,
    disable: bool = False,
    position: int = 0,
    leave: bool = True,
    ncols: Optional[int] = None,
    unit: str = "it",
    dynamic_ncols: bool = True,
):
    """
    Create a tqdm progress bar with consistent settings across the project.

    Args:
        iterable: Iterable to wrap with progress bar
        desc: Description for the progress bar
        total: Total number of iterations (if None, will try to get from iterable)
        disable: Whether to disable the progress bar
        position: Position of the progress bar (for multiple bars)
        leave: Whether to leave the progress bar after completion
        ncols: Width of the progress bar
        unit: Unit of iteration (e.g., "batch", "step")
        dynamic_ncols: Automatically adjust width based on terminal size

    Returns:
        tqdm progress bar object
    """
    # In distributed training, only rank 0 should draw a bar; every other
    # rank gets a disabled bar so the terminal is not flooded.
    try:
        import torch.distributed as dist
        # is_available() guards the is_initialized() call on builds without
        # distributed support.
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            disable = True
    except (ImportError, RuntimeError):
        # torch missing or the process group is unusable: fall back to the
        # caller-supplied `disable` flag instead of crashing.
        pass
    return tqdm(
        iterable,
        desc=desc,
        total=total,
        disable=disable,
        position=position,
        leave=leave,
        ncols=ncols,
        unit=unit,
        dynamic_ncols=dynamic_ncols,
        # Write to stdout so tqdm.write-based logging stays interleaved cleanly
        file=sys.stdout,
        # Smooth the rate estimate for steadier ETA display
        smoothing=0.1,
        # Consistent bar style project-wide
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
    )
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that routes records through ``tqdm.write`` so log
    lines do not corrupt any active progress bars."""

    def emit(self, record):
        # Format and hand off to tqdm, which repaints bars after writing.
        try:
            tqdm.write(self.format(record), file=sys.stdout)
            self.flush()
        except Exception:
            self.handleError(record)
def setup_logging(
    output_dir: Optional[str] = None,
    log_to_file: bool = False,
    log_level: int = logging.INFO,
    log_to_console: bool = True,
    use_tqdm_handler: bool = True,
    rank: int = 0,
    world_size: int = 1,
):
    """
    Setup logging with optional file and console handlers.

    Replaces any handlers already attached to the root logger, so it is safe
    to call more than once without duplicating output.

    Args:
        output_dir: Directory for log files (defaults to ./logs when logging to file)
        log_to_file: Whether to log to file
        log_level: Logging level
        log_to_console: Whether to log to console
        use_tqdm_handler: Use tqdm-compatible handler for console logging
        rank: Process rank for distributed training
        world_size: World size for distributed training
    """
    handlers: list[logging.Handler] = []
    log_path = None
    # Only log to console from the main process in distributed training
    if log_to_console and (world_size == 1 or rank == 0):
        if use_tqdm_handler:
            # tqdm-compatible handler keeps progress bars intact
            console_handler = TqdmLoggingHandler()
        else:
            console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(log_level)
        handlers.append(console_handler)
    if log_to_file:
        if output_dir is None:
            output_dir = "./logs"
        os.makedirs(output_dir, exist_ok=True)
        # Add rank to filename so every process gets its own file
        if world_size > 1:
            log_filename = f"train_rank{rank}.log"
        else:
            log_filename = "train.log"
        log_path = os.path.join(output_dir, log_filename)
        # Explicit utf-8: default encoding is platform-dependent and can
        # corrupt non-ASCII log messages (e.g. on Windows).
        file_handler = logging.FileHandler(log_path, encoding='utf-8')
        file_handler.setLevel(log_level)
        handlers.append(file_handler)
    # Create formatter; include the rank in distributed runs so interleaved
    # logs can be attributed to a process.
    if world_size > 1:
        formatter = logging.Formatter(
            f'[%(asctime)s][Rank {rank}] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    else:
        formatter = logging.Formatter(
            '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
    for handler in handlers:
        handler.setFormatter(formatter)
    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    # Remove existing handlers to avoid duplicate output on repeated calls
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    for handler in handlers:
        root_logger.addHandler(handler)
    # Announce configuration once, from the main process only
    if rank == 0:
        logging.info(
            f"Logging initialized. Level: {logging.getLevelName(log_level)} | "
            f"Log file: {log_path if log_to_file else 'none'} | "
            f"World size: {world_size}"
        )
def get_logger(name: str) -> logging.Logger:
    """Project-wide entry point for named loggers; thin wrapper over
    ``logging.getLogger`` so call sites stay consistent."""
    return logging.getLogger(name)
class MetricsLogger:
    """Logger specifically for training metrics.

    Appends one JSON object per call to ``<output_dir>/metrics.jsonl`` and
    writes config/summary snapshots as pretty-printed JSON files.
    """

    def __init__(self, output_dir: str):
        self.output_dir = output_dir
        self.metrics_file = os.path.join(output_dir, 'metrics.jsonl')
        os.makedirs(output_dir, exist_ok=True)

    def log(self, step: int, metrics: Dict[str, Any], phase: str = 'train'):
        """Append one metrics record (with step, phase, timestamp) to the JSONL file."""
        entry = {
            'step': step,
            'phase': phase,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics
        }
        # utf-8 + ensure_ascii=False so non-ASCII metric names/values survive
        # round-trips regardless of the platform's default encoding.
        with open(self.metrics_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    def log_config(self, config: Dict[str, Any]):
        """Write the run configuration to ``config.json`` (overwrites)."""
        config_file = os.path.join(self.output_dir, 'config.json')
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)

    def log_summary(self, summary: Dict[str, Any]):
        """Write the training summary to ``summary.json`` (overwrites)."""
        summary_file = os.path.join(self.output_dir, 'summary.json')
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)
class TensorBoardLogger:
    """TensorBoard logging wrapper.

    Degrades gracefully: if tensorboard is not installed (or ``enabled`` is
    False), every method becomes a no-op instead of raising.
    """

    def __init__(self, log_dir: str, enabled: bool = True):
        self.enabled = enabled
        # Always define the attribute: previously it was missing when
        # enabled=False, so any later access raised AttributeError.
        self.writer = None
        if self.enabled:
            try:
                from torch.utils.tensorboard.writer import SummaryWriter
                self.writer = SummaryWriter(log_dir)
            except ImportError:
                get_logger(__name__).warning(
                    "TensorBoard not available. Install with: pip install tensorboard"
                )
                self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int):
        """Log a scalar value"""
        if self.enabled and self.writer is not None:
            self.writer.add_scalar(tag, value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: Dict[str, float], step: int):
        """Log multiple scalars"""
        if self.enabled and self.writer is not None:
            self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: np.ndarray, step: int):
        """Log a histogram"""
        if self.enabled and self.writer is not None:
            self.writer.add_histogram(tag, values, step)

    def log_embedding(self, embeddings: torch.Tensor, metadata: Optional[list] = None, step: int = 0):
        """Log embeddings for visualization"""
        if self.enabled and self.writer is not None:
            self.writer.add_embedding(embeddings, metadata=metadata, global_step=step)

    def flush(self):
        """Flush the writer"""
        if self.enabled and self.writer is not None:
            self.writer.flush()

    def close(self):
        """Close the writer"""
        if self.enabled and self.writer is not None:
            self.writer.close()
class ProgressLogger:
    """Periodically logs training progress: loss, extra metrics, throughput
    (steps/s) and an ETA derived from the elapsed wall-clock time."""

    def __init__(self, total_steps: int, log_interval: int = 100):
        self.total_steps = total_steps
        self.log_interval = log_interval
        self.logger = get_logger(__name__)
        self.start_time = datetime.now()

    def log_step(self, step: int, loss: float, metrics: Optional[Dict[str, float]] = None):
        """Log progress for a training step (only every ``log_interval`` steps)."""
        if step % self.log_interval != 0:
            return
        elapsed = (datetime.now() - self.start_time).total_seconds()
        speed = step / elapsed if elapsed > 0 else 0
        eta = (self.total_steps - step) / speed if speed > 0 else 0
        parts = [f"Step {step}/{self.total_steps}", f"Loss: {loss:.4f}"]
        if metrics:
            parts.extend(f"{key}: {val:.4f}" for key, val in metrics.items())
        parts.append(f"Speed: {speed:.2f} steps/s")
        parts.append(f"ETA: {self._format_time(eta)}")
        self.logger.info(" | ".join(parts))

    def _format_time(self, seconds: float) -> str:
        """Render a duration as e.g. '2h 3m 4s', omitting leading zero units."""
        hours, rem = divmod(int(seconds), 3600)
        minutes, secs = divmod(rem, 60)
        if hours > 0:
            return f"{hours}h {minutes}m {secs}s"
        if minutes > 0:
            return f"{minutes}m {secs}s"
        return f"{secs}s"
def log_model_info(model: torch.nn.Module, logger: Optional[logging.Logger] = None):
    """Log the model's class name and parameter counts; per-parameter shapes
    are emitted at DEBUG level only."""
    log = logger if logger is not None else get_logger(__name__)
    # Materialize once so both counts walk the same snapshot of parameters.
    params = list(model.parameters())
    total = sum(p.numel() for p in params)
    trainable = sum(p.numel() for p in params if p.requires_grad)
    log.info(f"Model: {model.__class__.__name__}")
    log.info(f"Total parameters: {total:,}")
    log.info(f"Trainable parameters: {trainable:,}")
    log.info(f"Non-trainable parameters: {total - trainable:,}")
    log.debug("Parameter shapes:")
    for name, param in model.named_parameters():
        log.debug(f"  {name}: {list(param.shape)}")
def log_training_summary(
    config: Dict[str, Any],
    metrics: Dict[str, float],
    elapsed_time: float,
    output_dir: str
):
    """
    Write a comprehensive training summary to
    ``<output_dir>/training_summary.json`` and echo it to the log.

    Args:
        config: Run configuration (must be JSON-serializable)
        metrics: Final metric values, formatted to 4 decimal places in the log
        elapsed_time: Total training time in seconds
        output_dir: Existing directory to write the summary file into
    """
    logger = get_logger(__name__)
    summary = {
        'config': config,
        'final_metrics': metrics,
        'training_time': elapsed_time,
        'timestamp': datetime.now().isoformat()
    }
    summary_file = os.path.join(output_dir, 'training_summary.json')
    # utf-8 + ensure_ascii=False: keep non-ASCII config values readable and
    # independent of the platform's default encoding.
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    logger.info("=" * 50)
    logger.info("TRAINING SUMMARY")
    logger.info("=" * 50)
    logger.info(f"Training time: {elapsed_time:.2f} seconds")
    logger.info("Final metrics:")
    for metric, value in metrics.items():
        logger.info(f"  {metric}: {value:.4f}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info("=" * 50)