# bge_finetune/utils/distributed.py
import os
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import logging
from typing import Optional, Any, List
import random
import numpy as np
from .config_loader import get_config

# Import Ascend support utilities
try:
    from .ascend_support import (
        check_ascend_availability,
        get_ascend_backend,
        setup_ascend_device,
        clear_ascend_cache
    )
    ASCEND_SUPPORT_AVAILABLE = True
except ImportError:
    ASCEND_SUPPORT_AVAILABLE = False

logger = logging.getLogger(__name__)


def setup_distributed(config=None):
    """Setup the distributed training environment with proper error handling."""
    try:
        # Check whether we are in a distributed environment
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        rank = int(os.environ.get('RANK', 0))
        local_rank = int(os.environ.get('LOCAL_RANK', 0))

        if world_size == 1:
            logger.info("Single GPU/CPU training mode")
            return

        # Determine the backend from the available hardware:
        # Ascend NPU (HCCL) first, then CUDA (NCCL), otherwise CPU (gloo).
        if ASCEND_SUPPORT_AVAILABLE and check_ascend_availability():
            backend = 'hccl'
            try:
                setup_ascend_device(local_rank)
                logger.info("Using Ascend NPU with HCCL backend")
            except Exception as e:
                logger.warning(f"Failed to setup Ascend device: {e}, falling back to gloo")
                backend = 'gloo'
        elif torch.cuda.is_available():
            backend = 'nccl'
            # Bind this process to its local GPU
            torch.cuda.set_device(local_rank)
        else:
            backend = 'gloo'

        # Override the backend from the config if provided
        if config is not None:
            backend = config.get('distributed.backend', backend)

        # Prefer an explicit TCP rendezvous when MASTER_ADDR is set,
        # otherwise let PyTorch read everything from the environment.
        if os.environ.get('MASTER_ADDR'):
            init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ.get('MASTER_PORT', '29500')}"
        else:
            init_method = 'env://'

        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank
        )
        logger.info(
            f"Distributed training initialized: backend={backend}, "
            f"world_size={world_size}, rank={rank}, local_rank={local_rank}"
        )
    except Exception as e:
        logger.warning(f"Failed to initialize distributed training: {e}")
        logger.info("Falling back to single GPU/CPU training")
        # Reset environment variables to indicate single-GPU mode
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'
        os.environ['LOCAL_RANK'] = '0'
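
# The environment variables read above (WORLD_SIZE, RANK, LOCAL_RANK,
# MASTER_ADDR, MASTER_PORT) are the ones populated by torchrun. A typical
# launch therefore looks like the following, where train.py stands in for
# whatever entry script calls setup_distributed() (an assumption, not
# something this module enforces):
#
#   torchrun --nproc_per_node=4 --master_port=29500 train.py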


def cleanup_distributed():
    """Cleanup the distributed training environment with error handling."""
    try:
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Distributed process group destroyed.")
    except Exception as e:
        # Continue execution even if cleanup fails
        logger.warning(f"Error during distributed cleanup: {e}")


def is_main_process() -> bool:
    """Check whether this is the main (rank 0) process."""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def get_rank() -> int:
    """Get the current process rank."""
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_world_size() -> int:
    """Get the world size."""
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_local_rank() -> int:
    """Get the local rank (GPU index on the current node)."""
    if not dist.is_initialized():
        return 0
    return int(os.environ.get('LOCAL_RANK', 0))


def synchronize():
    """Synchronize all processes."""
    if dist.is_initialized():
        dist.barrier()


def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    All-reduce a tensor across all processes

    Args:
        tensor: Tensor to reduce
        op: Reduction operation

    Returns:
        Reduced tensor
    """
    if not dist.is_initialized():
        return tensor
    dist.all_reduce(tensor, op=op)
    return tensor


def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    All-gather tensors from all processes

    Args:
        tensor: Tensor to gather

    Returns:
        List of tensors from all processes
    """
    if not dist.is_initialized():
        return [tensor]
    world_size = get_world_size()
    tensors = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensors, tensor)
    return tensors


def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Broadcast a tensor from the source rank to all processes

    Args:
        tensor: Tensor to broadcast
        src: Source rank

    Returns:
        Broadcasted tensor
    """
    if not dist.is_initialized():
        return tensor
    dist.broadcast(tensor, src=src)
    return tensor


def reduce_dict(input_dict: dict, average: bool = True) -> dict:
    """
    Reduce a dictionary of scalar tensors across all processes

    Args:
        input_dict: Dictionary of scalar tensors to reduce
        average: Whether to average the values across processes

    Returns:
        Dictionary mapping the same keys to reduced Python floats
    """
    if not dist.is_initialized():
        return input_dict
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sort the keys so every process stacks values in the same order
        names = []
        values = []
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v.item() for k, v in zip(names, values)}
    return reduced_dict
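
# A sketch of the intended call pattern: assuming per-step scalar loss/metric
# tensors that already live on the communication device, this returns
# rank-averaged floats on every process (it is a collective, so call it on
# every rank). The names `loss` and `acc` are illustrative only.
#
#   stats = reduce_dict({'loss': loss.detach(), 'accuracy': acc.detach()})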


def setup_model_for_distributed(
    model: torch.nn.Module,
    device_ids: Optional[List[int]] = None,
    output_device: Optional[int] = None,
    find_unused_parameters: bool = True,
    broadcast_buffers: bool = True
) -> torch.nn.Module:
    """
    Wrap a model for distributed training

    Args:
        model: Model to wrap
        device_ids: GPU indices to use
        output_device: Device where outputs are gathered
        find_unused_parameters: Whether to find unused parameters
        broadcast_buffers: Whether to broadcast buffers

    Returns:
        Distributed (DDP-wrapped) model
    """
    if not dist.is_initialized():
        return model
    # DDP expects device_ids=None for CPU (gloo) training, so only default to
    # the local GPU when CUDA is actually available.
    if device_ids is None and torch.cuda.is_available():
        device_ids = [get_local_rank()]
    if output_device is None and device_ids is not None:
        output_device = device_ids[0]
    model = DDP(
        model,
        device_ids=device_ids,
        output_device=output_device,
        find_unused_parameters=find_unused_parameters,
        broadcast_buffers=broadcast_buffers
    )
    logger.info(f"Wrapped model for distributed training on devices {device_ids}")
    return model
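
# Note: find_unused_parameters=True makes DDP traverse the autograd graph after
# every backward pass to find parameters that received no gradient, which adds
# overhead. If the whole model always contributes to the loss (a per-project
# assumption), the cheaper call is:
#
#   model = setup_model_for_distributed(model, find_unused_parameters=False)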


def create_distributed_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = False,
    seed: int = 42
) -> DataLoader:
    """
    Create a distributed dataloader

    Args:
        dataset: Dataset to load
        batch_size: Batch size per GPU
        shuffle: Whether to shuffle
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory
        drop_last: Whether to drop the last incomplete batch
        seed: Random seed for shuffling

    Returns:
        Distributed dataloader
    """
    if dist.is_initialized():
        sampler = DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last
        )
        shuffle = False  # The sampler handles shuffling
    else:
        sampler = None
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last
    )
    return dataloader
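
# When the DistributedSampler is active, the training loop should advance its
# epoch counter so that shuffling differs between epochs (standard PyTorch
# usage, not enforced here; `num_epochs` is illustrative):
#
#   for epoch in range(num_epochs):
#       if isinstance(dataloader.sampler, DistributedSampler):
#           dataloader.sampler.set_epoch(epoch)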


def set_random_seed(seed: int, deterministic: bool = False):
    """
    Set random seeds for reproducibility

    Args:
        seed: Random seed
        deterministic: Whether to use deterministic cuDNN algorithms
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic and torch.cuda.is_available():
        # Trade speed for reproducibility in cuDNN
        cudnn.deterministic = True
        cudnn.benchmark = False
    logger.info(f"Set random seed to {seed} (deterministic={deterministic})")


def print_distributed_info():
    """Log distributed training information."""
    if not dist.is_initialized():
        logger.info("Distributed training not initialized")
        return
    logger.info("Distributed Info:")
    logger.info(f"  Backend: {dist.get_backend()}")
    logger.info(f"  World size: {get_world_size()}")
    logger.info(f"  Rank: {get_rank()}")
    logger.info(f"  Local rank: {get_local_rank()}")
    logger.info(f"  Master addr: {os.environ.get('MASTER_ADDR', 'N/A')}")
    logger.info(f"  Master port: {os.environ.get('MASTER_PORT', 'N/A')}")


class DistributedMetricAggregator:
    """Helper class to aggregate metrics across distributed processes."""

    def __init__(self):
        self.metrics = {}

    def update(self, **kwargs):
        """Record the latest value for each named metric."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.detach()
            else:
                v = torch.tensor(v, dtype=torch.float32)
            if torch.cuda.is_available():
                v = v.cuda()
            self.metrics[k] = v

    def aggregate(self, average: bool = True) -> dict:
        """Aggregate the recorded metrics across all processes."""
        if not dist.is_initialized():
            return {k: v.item() if isinstance(v, torch.Tensor) else v
                    for k, v in self.metrics.items()}
        aggregated = {}
        world_size = get_world_size()
        for k, v in self.metrics.items():
            # Sum across processes, then optionally average
            all_reduce(v)
            if average:
                v = v / world_size
            aggregated[k] = v.item()
        return aggregated

    def reset(self):
        """Reset the recorded metrics."""
        self.metrics = {}
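
# Sketch of the intended use inside a training loop (`loss` and `num_correct`
# are illustrative names, not defined by this module):
#
#   aggregator = DistributedMetricAggregator()
#   aggregator.update(loss=loss.detach(), num_correct=num_correct)
#   metrics = aggregator.aggregate(average=True)  # collective: call on every rank
#   if is_main_process():
#       logger.info(metrics)
#   aggregator.reset()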


def save_on_master(*args, **kwargs):
    """Decorator factory: the wrapped function only runs on the master process.

    Use as ``@save_on_master()``; arguments passed to the factory are ignored.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            if is_main_process():
                return func(*args, **kwargs)
            # Return None on non-master processes
            return None
        return wrapper
    return decorator


def log_on_master(logger: logging.Logger):
    """Get a logger proxy that only logs on the master process."""
    class MasterOnlyLogger:
        def __init__(self, logger):
            self.logger = logger

        def __getattr__(self, name):
            if is_main_process():
                return getattr(self.logger, name)
            # Swallow log calls on non-master processes
            return lambda *args, **kwargs: None

    return MasterOnlyLogger(logger)
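

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only, not part of the original
# module). The toy dataset and linear model below are placeholders for the
# real BGE fine-tuning data/model; the script runs single-process as-is and
# exercises the distributed branches when launched under torchrun.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    logging.basicConfig(level=logging.INFO)
    setup_distributed()
    set_random_seed(42)

    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    model = torch.nn.Linear(8, 1)
    if torch.cuda.is_available():
        model = model.cuda(get_local_rank())
    model = setup_model_for_distributed(model, find_unused_parameters=False)

    loader = create_distributed_dataloader(dataset, batch_size=16)
    aggregator = DistributedMetricAggregator()
    for x, y in loader:
        if torch.cuda.is_available():
            x, y = x.cuda(get_local_rank()), y.cuda(get_local_rank())
        loss = torch.nn.functional.mse_loss(model(x), y)
        aggregator.update(loss=loss)

    # aggregate() is a collective in distributed mode, so call it on every rank
    metrics = aggregator.aggregate()
    if is_main_process():
        print_distributed_info()
        logger.info(f"Aggregated metrics: {metrics}")

    cleanup_distributed()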