init commit
This commit is contained in:
425
utils/distributed.py
Normal file
425
utils/distributed.py
Normal file
@@ -0,0 +1,425 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
import logging
|
||||
from typing import Optional, Any, List
|
||||
import random
|
||||
import numpy as np
|
||||
from .config_loader import get_config
|
||||
|
||||
# Import Ascend support utilities
|
||||
try:
|
||||
from .ascend_support import (
|
||||
check_ascend_availability,
|
||||
get_ascend_backend,
|
||||
setup_ascend_device,
|
||||
clear_ascend_cache
|
||||
)
|
||||
ASCEND_SUPPORT_AVAILABLE = True
|
||||
except ImportError:
|
||||
ASCEND_SUPPORT_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_distributed(config=None):
    """Initialise the default torch.distributed process group from env vars.

    Reads ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` from the environment
    (defaulting to 1, 0 and 0). When ``WORLD_SIZE`` is 1 nothing is
    initialised. Otherwise a backend is selected (``hccl`` when the optional
    Ascend helpers report an NPU, ``nccl`` when CUDA is available, else
    ``gloo``) and ``dist.init_process_group`` is called.

    Any exception raised along the way is logged and swallowed; the three
    environment variables are then rewritten to describe a single process.

    Args:
        config: Optional object exposing ``get(key, default)``. When given,
            its ``'distributed.backend'`` entry overrides the detected backend.

    Returns:
        None.
    """
    env = os.environ
    try:
        # Launcher-provided topology; missing variables mean "one process".
        world_size = int(env.get('WORLD_SIZE', 1))
        rank = int(env.get('RANK', 0))
        local_rank = int(env.get('LOCAL_RANK', 0))

        if world_size == 1:
            logger.info("Single GPU/CPU training mode")
            return

        # Start from the CPU backend and upgrade when an accelerator is found.
        backend = 'gloo'
        if ASCEND_SUPPORT_AVAILABLE and check_ascend_availability():
            try:
                setup_ascend_device(local_rank)
                logger.info("Using Ascend NPU with HCCL backend")
                backend = 'hccl'
            except Exception as exc:
                # Backend stays 'gloo' when the NPU cannot be configured.
                logger.warning(f"Failed to setup Ascend device: {exc}, falling back")
        elif torch.cuda.is_available():
            # Bind this process to its GPU before NCCL is initialised.
            torch.cuda.set_device(local_rank)
            backend = 'nccl'

        # An explicit configuration entry wins over auto-detection.
        if config is not None:
            backend = config.get('distributed.backend', backend)

        # Use an explicit TCP rendezvous when MASTER_ADDR is set and non-empty.
        master_addr = env.get('MASTER_ADDR', None)
        if master_addr:
            master_port = env.get('MASTER_PORT', '29500')
            init_method = f"tcp://{master_addr}:{master_port}"
        else:
            init_method = 'env://'

        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )

        logger.info(f"Distributed training initialized: backend={backend}, world_size={world_size}, rank={rank}, local_rank={local_rank}")

    except Exception as exc:
        logger.warning(f"Failed to initialize distributed training: {exc}")
        logger.info("Falling back to single GPU/CPU training")
        # Advertise single-process mode to anything that reads these later.
        env['WORLD_SIZE'] = '1'
        env['RANK'] = '0'
        env['LOCAL_RANK'] = '0'
|
||||
|
||||
|
||||
def cleanup_distributed():
    """Destroy the default process group if one has been initialised.

    Errors raised during teardown are logged as warnings and suppressed so
    that shutdown continues.

    Returns:
        None.
    """
    try:
        if not dist.is_initialized():
            return
        dist.destroy_process_group()
        logger.info("Distributed process group destroyed.")
    except Exception as exc:
        # Best effort: a failed cleanup must not abort the caller.
        logger.warning(f"Error during distributed cleanup: {exc}")
|
||||
|
||||
|
||||
def is_main_process() -> bool:
    """Return True when this process is rank 0 or no process group exists."""
    return (not dist.is_initialized()) or dist.get_rank() == 0
|
||||
|
||||
|
||||
def get_rank() -> int:
    """Return this process's rank, or 0 when no process group exists."""
    return dist.get_rank() if dist.is_initialized() else 0
|
||||
|
||||
|
||||
def get_world_size() -> int:
    """Return the process-group size, or 1 when no process group exists."""
    return dist.get_world_size() if dist.is_initialized() else 1
|
||||
|
||||
|
||||
def get_local_rank() -> int:
    """Return the local rank taken from the ``LOCAL_RANK`` environment variable.

    The variable is only consulted when a process group is initialised;
    otherwise 0 is returned regardless of the environment. A missing
    variable also yields 0.
    """
    if dist.is_initialized():
        return int(os.environ.get('LOCAL_RANK', 0))
    return 0
|
||||
|
||||
|
||||
def synchronize():
    """Run ``dist.barrier()`` when a process group exists; otherwise do nothing."""
    if not dist.is_initialized():
        return
    dist.barrier()
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    Reduce a tensor across all processes, in place.

    Args:
        tensor: Tensor to reduce; it is passed directly to
            ``dist.all_reduce`` and therefore overwritten with the result.
        op: Reduction operation (defaults to sum).

    Returns:
        The same tensor object that was passed in. Without an initialised
        process group it is returned untouched.
    """
    if dist.is_initialized():
        dist.all_reduce(tensor, op=op)
    return tensor
|
||||
|
||||
|
||||
def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    Collect one tensor from every process.

    Args:
        tensor: This process's contribution. Receive buffers are created
            with ``torch.empty_like(tensor)``.

    Returns:
        A list with one tensor per process. Without an initialised process
        group the list contains only the input tensor itself.
    """
    if dist.is_initialized():
        # One receive buffer per participating process.
        gathered = [torch.empty_like(tensor) for _ in range(get_world_size())]
        dist.all_gather(gathered, tensor)
        return gathered
    return [tensor]
|
||||
|
||||
|
||||
def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Broadcast a tensor from one rank to all processes, in place.

    Args:
        tensor: Tensor handed to ``dist.broadcast``.
        src: Rank whose data is sent.

    Returns:
        The same tensor object that was passed in. Without an initialised
        process group it is returned untouched.
    """
    if dist.is_initialized():
        dist.broadcast(tensor, src=src)
    return tensor
|
||||
|
||||
|
||||
def reduce_dict(input_dict: dict, average: bool = True) -> dict:
    """
    Reduce a dictionary of tensors across all processes.

    The values are stacked in sorted-key order, summed across processes
    with a single ``dist.all_reduce`` and optionally divided by the world
    size.

    Args:
        input_dict: Mapping from name to tensor. The values are stacked with
            ``torch.stack``, so they must be tensors of one common shape;
            each reduced value is converted with ``.item()``.
        average: Divide the summed values by the world size when True.

    Returns:
        ``input_dict`` itself when no process group is initialised or the
        world size is below 2. Otherwise a new dict mapping each key to a
        Python number (an empty dict for empty input).
    """
    if not dist.is_initialized():
        return input_dict

    world_size = get_world_size()
    if world_size < 2:
        return input_dict

    if not input_dict:
        # torch.stack rejects an empty sequence; nothing to reduce anyway.
        return {}

    with torch.no_grad():
        # Sorted keys give every process the same stacking order.
        names = sorted(input_dict)
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(values)

        if average:
            # Out-of-place division: the in-place form raises for integer
            # tensors because the float result cannot be stored back.
            values = values / world_size

        reduced_dict = {name: value.item() for name, value in zip(names, values)}

    return reduced_dict
|
||||
|
||||
|
||||
def setup_model_for_distributed(
    model: torch.nn.Module,
    device_ids: Optional[List[int]] = None,
    output_device: Optional[int] = None,
    find_unused_parameters: bool = True,
    broadcast_buffers: bool = True
) -> torch.nn.Module:
    """
    Wrap a model in DistributedDataParallel.

    Args:
        model: Model to wrap.
        device_ids: Device indices handed to DDP. When omitted, defaults to
            ``[get_local_rank()]`` unless the model's first parameter lives
            on the CPU, in which case it stays ``None``.
        output_device: Device index handed to DDP. When omitted, defaults to
            the first entry of ``device_ids`` if there is one.
        find_unused_parameters: Forwarded to DDP.
        broadcast_buffers: Forwarded to DDP.

    Returns:
        The DDP-wrapped model, or ``model`` unchanged when no process group
        is initialised.
    """
    if not dist.is_initialized():
        return model

    if device_ids is None:
        first_param = next(model.parameters(), None)
        on_cpu = first_param is not None and first_param.device.type == 'cpu'
        # DDP requires device_ids=None for CPU modules (e.g. gloo runs), so
        # only default to the local device for accelerator-resident models.
        if not on_cpu:
            device_ids = [get_local_rank()]

    # Guard against None/empty device_ids instead of indexing blindly.
    if output_device is None and device_ids:
        output_device = device_ids[0]

    model = DDP(
        model,
        device_ids=device_ids,
        output_device=output_device,
        find_unused_parameters=find_unused_parameters,
        broadcast_buffers=broadcast_buffers
    )

    logger.info(f"Wrapped model for distributed training on devices {device_ids}")

    return model
|
||||
|
||||
|
||||
def create_distributed_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = False,
    seed: int = 42
) -> DataLoader:
    """
    Build a DataLoader, sharded with a DistributedSampler when distributed.

    Args:
        dataset: Dataset to load.
        batch_size: Batch size passed to the DataLoader.
        shuffle: Whether to shuffle. With an initialised process group the
            flag is given to the sampler and the DataLoader receives False.
        num_workers: Number of data loading workers.
        pin_memory: Whether to pin memory.
        drop_last: Passed to both the sampler (when used) and the DataLoader.
        seed: Seed passed to the DistributedSampler.

    Returns:
        The configured DataLoader.
    """
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "sampler": None,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "drop_last": drop_last,
    }

    if dist.is_initialized():
        loader_kwargs["sampler"] = DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        # DataLoader forbids shuffle=True together with a sampler.
        loader_kwargs["shuffle"] = False

    return DataLoader(dataset, **loader_kwargs)
|
||||
|
||||
|
||||
def set_random_seed(seed: int, deterministic: bool = False):
    """
    Seed Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
        deterministic: When True and CUDA is available, set
            ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if deterministic and torch.cuda.is_available():
        # Configure cuDNN determinism settings
        cudnn.benchmark = False
        cudnn.deterministic = True

    logger.info(f"Set random seed to {seed} (deterministic={deterministic})")
|
||||
|
||||
|
||||
def print_distributed_info():
    """Log backend, world size, ranks and master address/port at INFO level.

    Logs a single "not initialized" line instead when no process group
    exists.
    """
    if not dist.is_initialized():
        logger.info("Distributed training not initialized")
        return

    env = os.environ
    report = (
        f"Distributed Info:",
        f" Backend: {dist.get_backend()}",
        f" World size: {get_world_size()}",
        f" Rank: {get_rank()}",
        f" Local rank: {get_local_rank()}",
        f" Master addr: {env.get('MASTER_ADDR', 'N/A')}",
        f" Master port: {env.get('MASTER_PORT', 'N/A')}",
    )
    for line in report:
        logger.info(line)
|
||||
|
||||
|
||||
class DistributedMetricAggregator:
    """Helper class to aggregate metrics across distributed processes"""

    def __init__(self):
        # Maps metric name -> tensor holding the latest value.
        self.metrics = {}

    def update(self, **kwargs):
        """Store the latest value of each named metric.

        Tensors are detached; other values are converted to float32
        tensors. When CUDA is available each value is moved to the current
        CUDA device. A later update for the same name overwrites the
        earlier one.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.detach()
            else:
                v = torch.tensor(v, dtype=torch.float32)

            if torch.cuda.is_available():
                v = v.cuda()

            self.metrics[k] = v

    def aggregate(self, average: bool = True) -> dict:
        """Aggregate metrics across all processes.

        Args:
            average: Divide each summed metric by the world size when True.

        Returns:
            Dict mapping each metric name to a Python number. Without an
            initialised process group the local values are returned.
        """
        if not dist.is_initialized():
            return {k: v.item() if isinstance(v, torch.Tensor) else v
                    for k, v in self.metrics.items()}

        aggregated = {}
        world_size = get_world_size()

        for k, v in self.metrics.items():
            # Reduce a copy: all_reduce works in place, and the stored
            # tensor may share storage with the caller's tensor (detach()
            # does not copy). Reducing it directly would overwrite the
            # caller's value and compound on repeated aggregate() calls.
            reduced = all_reduce(v.clone())

            # Average if requested
            if average:
                reduced = reduced / world_size

            aggregated[k] = reduced.item()

        return aggregated

    def reset(self):
        """Discard all stored metrics."""
        self.metrics = {}
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
    """Decorator that runs the wrapped function only on the main process.

    Works both as ``@save_on_master()`` and as bare ``@save_on_master``.
    On processes where ``is_main_process()`` is False the wrapped function
    is not called and ``None`` is returned.

    Args:
        *args: Either empty (factory form) or the single function being
            decorated (bare form). Any other arguments are ignored.
        **kwargs: Ignored.

    Returns:
        The decorator in factory form, or the wrapped function in bare form.
    """
    # Local import keeps this block self-contained.
    from functools import wraps

    def decorator(func):
        @wraps(func)  # preserve the wrapped function's name and docstring
        def wrapper(*call_args, **call_kwargs):
            if is_main_process():
                return func(*call_args, **call_kwargs)
            # Return None on non-master processes
            return None

        return wrapper

    # Bare usage passes the function itself as the only argument. Without
    # this check the function would be replaced by `decorator` and never run.
    if len(args) == 1 and not kwargs and callable(args[0]):
        return decorator(args[0])

    return decorator
|
||||
|
||||
|
||||
def log_on_master(logger: logging.Logger):
    """Wrap a logger so that its attributes are only usable on the main process.

    Args:
        logger: Logger to wrap.

    Returns:
        A proxy whose attribute lookups are forwarded to ``logger`` when
        ``is_main_process()`` is True. On other processes every forwarded
        lookup yields a function that accepts any arguments and returns
        None.
    """

    class MasterOnlyLogger:
        def __init__(self, wrapped):
            self.logger = wrapped

        def __getattr__(self, name):
            # Only reached for names not found on the proxy itself.
            if not is_main_process():
                return lambda *args, **kwargs: None
            return getattr(self.logger, name)

    return MasterOnlyLogger(logger)
|
||||
Reference in New Issue
Block a user