import os
import logging
import random
from typing import Optional, Any, List

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

from .config_loader import get_config

# Import Ascend NPU support utilities if they are available
try:
    from .ascend_support import (
        check_ascend_availability,
        get_ascend_backend,
        setup_ascend_device,
        clear_ascend_cache,
    )
    ASCEND_SUPPORT_AVAILABLE = True
except ImportError:
    ASCEND_SUPPORT_AVAILABLE = False

logger = logging.getLogger(__name__)


def setup_distributed(config=None):
    """Set up the distributed training environment with proper error handling."""
    try:
        # Check whether we are running in a distributed environment
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        rank = int(os.environ.get('RANK', 0))
        local_rank = int(os.environ.get('LOCAL_RANK', 0))

        if world_size == 1:
            logger.info("Single GPU/CPU training mode")
            return

        # Determine the backend based on available hardware,
        # checking for Ascend NPU support first.
        if ASCEND_SUPPORT_AVAILABLE and check_ascend_availability():
            backend = 'hccl'
            try:
                setup_ascend_device(local_rank)
                logger.info("Using Ascend NPU with HCCL backend")
            except Exception as e:
                logger.warning(f"Failed to set up Ascend device: {e}, falling back to gloo")
                backend = 'gloo'
        elif torch.cuda.is_available():
            backend = 'nccl'
            # Bind this process to its local CUDA device
            torch.cuda.set_device(local_rank)
        else:
            backend = 'gloo'

        # Override the backend from the config if one is provided
        if config is not None:
            backend = config.get('distributed.backend', backend)

        # Prefer an explicit TCP rendezvous address; otherwise fall back to
        # environment-variable initialization.
        master_addr = os.environ.get('MASTER_ADDR')
        if master_addr:
            init_method = f"tcp://{master_addr}:{os.environ.get('MASTER_PORT', '29500')}"
        else:
            init_method = 'env://'

        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank
        )

        logger.info(
            f"Distributed training initialized: backend={backend}, "
            f"world_size={world_size}, rank={rank}, local_rank={local_rank}"
        )

    except Exception as e:
        logger.warning(f"Failed to initialize distributed training: {e}")
        logger.info("Falling back to single GPU/CPU training")
        # Set environment variables to indicate single-process mode
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'
        os.environ['LOCAL_RANK'] = '0'
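
# Typical lifecycle (illustrative sketch, not part of the original module): a
# training entry point launched with torchrun would call setup_distributed()
# once at startup and cleanup_distributed() at exit, e.g.
#
#   setup_distributed()                      # reads WORLD_SIZE/RANK/LOCAL_RANK
#   try:
#       ...                                  # build model, data, training loop
#   finally:
#       cleanup_distributed()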


def cleanup_distributed():
    """Clean up the distributed training environment with error handling."""
    try:
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Distributed process group destroyed.")
    except Exception as e:
        logger.warning(f"Error during distributed cleanup: {e}")
        # Continue execution even if cleanup fails


def is_main_process() -> bool:
    """Check whether this is the main (rank 0) process."""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def get_rank() -> int:
    """Get the current process rank."""
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_world_size() -> int:
    """Get the world size."""
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_local_rank() -> int:
    """Get the local rank (device index on the current node)."""
    if not dist.is_initialized():
        return 0
    return int(os.environ.get('LOCAL_RANK', 0))


def synchronize():
    """Synchronize all processes."""
    if dist.is_initialized():
        dist.barrier()
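
# Illustrative sketch (assumed usage, not from the original module): these
# helpers are typically combined to restrict side effects to rank 0 while
# keeping the other ranks in step, e.g.
#
#   if is_main_process():
#       torch.save(model.state_dict(), "checkpoint.pt")
#   synchronize()   # make every rank wait until the checkpoint exists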


def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    All-reduce a tensor across all processes.

    Args:
        tensor: Tensor to reduce (reduced in place)
        op: Reduction operation

    Returns:
        The reduced tensor
    """
    if not dist.is_initialized():
        return tensor

    dist.all_reduce(tensor, op=op)
    return tensor


def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    All-gather tensors from all processes.

    The tensor must have the same shape, dtype and device placement on every
    rank, as required by ``torch.distributed.all_gather``.

    Args:
        tensor: Tensor to gather

    Returns:
        List of tensors from all processes
    """
    if not dist.is_initialized():
        return [tensor]

    world_size = get_world_size()
    tensors = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensors, tensor)
    return tensors


def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Broadcast a tensor from the source rank to all processes.

    Args:
        tensor: Tensor to broadcast (overwritten in place on non-source ranks)
        src: Source rank

    Returns:
        The broadcast tensor
    """
    if not dist.is_initialized():
        return tensor

    dist.broadcast(tensor, src=src)
    return tensor
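
# Illustrative sketch (assumed usage, not from the original module): averaging
# a per-rank scalar such as a loss value with the collectives above, e.g.
#
#   loss_t = loss.detach().clone()
#   all_reduce(loss_t)                       # sum over all ranks, in place
#   mean_loss = loss_t.item() / get_world_size()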


def reduce_dict(input_dict: dict, average: bool = True) -> dict:
    """
    Reduce a dictionary of values across all processes.

    The dictionary values are expected to be scalar tensors that live on the
    same device on every rank.

    Args:
        input_dict: Dictionary to reduce
        average: Whether to average the values

    Returns:
        Reduced dictionary of Python floats
    """
    if not dist.is_initialized():
        return input_dict

    world_size = get_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        names = []
        values = []

        # Sort the keys so every rank stacks the values in the same order
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])

        values = torch.stack(values, dim=0)
        dist.all_reduce(values)

        if average:
            values /= world_size

        reduced_dict = {k: v.item() for k, v in zip(names, values)}

    return reduced_dict
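
# Illustrative sketch (assumed usage, not from the original module): reducing a
# dictionary of scalar loss tensors so every rank logs the same numbers, e.g.
#
#   stats = {"loss": loss.detach(), "acc": acc.detach()}
#   stats = reduce_dict(stats, average=True)   # values come back as floats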


def setup_model_for_distributed(
    model: torch.nn.Module,
    device_ids: Optional[List[int]] = None,
    output_device: Optional[int] = None,
    find_unused_parameters: bool = True,
    broadcast_buffers: bool = True
) -> torch.nn.Module:
    """
    Wrap a model for distributed training.

    Args:
        model: Model to wrap
        device_ids: GPU indices to use
        output_device: Device where outputs are gathered
        find_unused_parameters: Whether to find unused parameters
        broadcast_buffers: Whether to broadcast buffers

    Returns:
        The DDP-wrapped model
    """
    if not dist.is_initialized():
        return model

    if device_ids is None:
        device_ids = [get_local_rank()]

    if output_device is None:
        output_device = device_ids[0]

    model = DDP(
        model,
        device_ids=device_ids,
        output_device=output_device,
        find_unused_parameters=find_unused_parameters,
        broadcast_buffers=broadcast_buffers
    )

    logger.info(f"Wrapped model for distributed training on devices {device_ids}")

    return model
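
# Illustrative sketch (assumed usage, not from the original module): the model
# should already sit on its local device before wrapping, e.g.
#
#   model = model.to(f"cuda:{get_local_rank()}")
#   model = setup_model_for_distributed(model)
#
# Note that find_unused_parameters=True adds per-iteration graph traversal
# overhead in DDP; pass find_unused_parameters=False when every parameter
# receives a gradient in each forward pass.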


def create_distributed_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = False,
    seed: int = 42
) -> DataLoader:
    """
    Create a distributed dataloader.

    Args:
        dataset: Dataset to load
        batch_size: Batch size per GPU
        shuffle: Whether to shuffle
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory
        drop_last: Whether to drop the last incomplete batch
        seed: Random seed for shuffling

    Returns:
        Distributed dataloader
    """
    if dist.is_initialized():
        sampler = DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last
        )
        shuffle = False  # The sampler handles shuffling
    else:
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last
    )

    return dataloader
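
# Illustrative sketch (assumed usage, not from the original module): with a
# DistributedSampler, call set_epoch() at the start of every epoch so that the
# shuffling order differs between epochs, e.g.
#
#   loader = create_distributed_dataloader(train_set, batch_size=32)
#   for epoch in range(num_epochs):
#       if isinstance(loader.sampler, DistributedSampler):
#           loader.sampler.set_epoch(epoch)
#       for batch in loader:
#           ...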


def set_random_seed(seed: int, deterministic: bool = False):
    """
    Set the random seed for reproducibility.

    Args:
        seed: Random seed
        deterministic: Whether to use deterministic cuDNN algorithms
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic and torch.cuda.is_available():
        # Configure cuDNN determinism settings
        cudnn.deterministic = True
        cudnn.benchmark = False

    logger.info(f"Set random seed to {seed} (deterministic={deterministic})")


def print_distributed_info():
    """Log distributed training information."""
    if not dist.is_initialized():
        logger.info("Distributed training not initialized")
        return

    logger.info("Distributed Info:")
    logger.info(f"  Backend: {dist.get_backend()}")
    logger.info(f"  World size: {get_world_size()}")
    logger.info(f"  Rank: {get_rank()}")
    logger.info(f"  Local rank: {get_local_rank()}")
    logger.info(f"  Master addr: {os.environ.get('MASTER_ADDR', 'N/A')}")
    logger.info(f"  Master port: {os.environ.get('MASTER_PORT', 'N/A')}")


class DistributedMetricAggregator:
    """Helper class to aggregate metrics across distributed processes."""

    def __init__(self):
        self.metrics = {}

    def update(self, **kwargs):
        """Update metrics with scalar values or tensors."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.detach()
            else:
                v = torch.tensor(v, dtype=torch.float32)

            # Collectives need the tensor on the backend's device (GPU for NCCL)
            if torch.cuda.is_available():
                v = v.cuda()

            self.metrics[k] = v

    def aggregate(self, average: bool = True) -> dict:
        """Aggregate metrics across all processes."""
        if not dist.is_initialized():
            return {k: v.item() if isinstance(v, torch.Tensor) else v
                    for k, v in self.metrics.items()}

        aggregated = {}
        world_size = get_world_size()

        for k, v in self.metrics.items():
            # Sum across all ranks (in place)
            all_reduce(v)

            # Average if requested
            if average:
                v = v / world_size

            aggregated[k] = v.item()

        return aggregated

    def reset(self):
        """Reset metrics."""
        self.metrics = {}
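
# Illustrative sketch (assumed usage, not from the original module):
#
#   agg = DistributedMetricAggregator()
#   agg.update(loss=loss, correct=num_correct)
#   stats = agg.aggregate(average=True)      # identical dict on every rank
#   agg.reset()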


def save_on_master(*args, **kwargs):
    """Decorator factory that runs the wrapped function only on the master process.

    Any arguments passed to the factory are ignored, so it is applied as
    ``@save_on_master()``.
    """

    def decorator(func):
        def wrapper(*args, **kwargs):
            if is_main_process():
                return func(*args, **kwargs)
            # Return None on non-master processes
            return None

        return wrapper

    return decorator
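
# Illustrative sketch (assumed usage, not from the original module): note the
# trailing parentheses, since save_on_master is a decorator factory, e.g.
#
#   @save_on_master()
#   def save_checkpoint(state, path):
#       torch.save(state, path)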


def log_on_master(logger: logging.Logger):
    """Get a logger proxy that only logs on the master process."""

    class MasterOnlyLogger:
        def __init__(self, logger):
            self.logger = logger

        def __getattr__(self, name):
            if is_main_process():
                return getattr(self.logger, name)
            # Swallow logging calls on non-master processes
            return lambda *args, **kwargs: None

    return MasterOnlyLogger(logger)
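

# Minimal smoke test (illustrative sketch, not part of the original module):
# running this module with ``python -m`` (so the relative imports resolve)
# exercises the single-process fallback path.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    setup_distributed()
    print_distributed_info()
    set_random_seed(42)
    cleanup_distributed()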