import os
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import logging
from typing import Optional, Any, List
import random
import numpy as np

from .config_loader import get_config

# Import Ascend support utilities
try:
    from .ascend_support import (
        check_ascend_availability,
        get_ascend_backend,
        setup_ascend_device,
        clear_ascend_cache
    )
    ASCEND_SUPPORT_AVAILABLE = True
except ImportError:
    ASCEND_SUPPORT_AVAILABLE = False

logger = logging.getLogger(__name__)


def setup_distributed(config=None):
    """Setup distributed training environment with proper error handling."""
    try:
        # Check if we're in a distributed environment
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        rank = int(os.environ.get('RANK', 0))
        local_rank = int(os.environ.get('LOCAL_RANK', 0))

        if world_size == 1:
            logger.info("Single GPU/CPU training mode")
            return

        # Determine backend based on hardware availability.
        # First check for Ascend NPU support.
        if ASCEND_SUPPORT_AVAILABLE and check_ascend_availability():
            backend = 'hccl'
            try:
                setup_ascend_device(local_rank)
                logger.info("Using Ascend NPU with HCCL backend")
            except Exception as e:
                logger.warning(f"Failed to setup Ascend device: {e}, falling back")
                backend = 'gloo'
        elif torch.cuda.is_available():
            backend = 'nccl'
            # Bind this process to its local GPU
            torch.cuda.set_device(local_rank)
        else:
            backend = 'gloo'

        # Override backend from config if provided
        if config is not None:
            backend = config.get('distributed.backend', backend)

        # Initialize process group
        if os.environ.get('MASTER_ADDR'):
            init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ.get('MASTER_PORT', '29500')}"
        else:
            init_method = 'env://'

        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank
        )

        logger.info(
            f"Distributed training initialized: backend={backend}, "
            f"world_size={world_size}, rank={rank}, local_rank={local_rank}"
        )

    except Exception as e:
        logger.warning(f"Failed to initialize distributed training: {e}")
        logger.info("Falling back to single GPU/CPU training")
        # Set environment variables to indicate single GPU mode
        os.environ['WORLD_SIZE'] = '1'
        os.environ['RANK'] = '0'
        os.environ['LOCAL_RANK'] = '0'


def cleanup_distributed():
    """Cleanup distributed training environment with error handling."""
    try:
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info("Distributed process group destroyed.")
    except Exception as e:
        logger.warning(f"Error during distributed cleanup: {e}")
        # Continue execution even if cleanup fails


def is_main_process() -> bool:
    """Check if this is the main process."""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def get_rank() -> int:
    """Get current process rank."""
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_world_size() -> int:
    """Get world size."""
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_local_rank() -> int:
    """Get local rank (GPU index on current node)."""
    if not dist.is_initialized():
        return 0
    return int(os.environ.get('LOCAL_RANK', 0))


def synchronize():
    """Synchronize all processes."""
    if dist.is_initialized():
        dist.barrier()
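# Usage sketch (illustrative only): a training entry point would typically call
# setup_distributed() once at startup, guard rank-0-only work with is_main_process(),
# and always tear down the process group. The `train()` and `args` names below are
# hypothetical placeholders, not part of this module.
#
#     setup_distributed()
#     try:
#         if is_main_process():
#             logger.info(f"Launching run on {get_world_size()} process(es)")
#         train(args)
#     finally:
#         cleanup_distributed()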
def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM) -> torch.Tensor:
    """
    All-reduce a tensor across all processes

    Args:
        tensor: Tensor to reduce
        op: Reduction operation

    Returns:
        Reduced tensor
    """
    if not dist.is_initialized():
        return tensor
    dist.all_reduce(tensor, op=op)
    return tensor


def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """
    All-gather tensors from all processes

    Args:
        tensor: Tensor to gather

    Returns:
        List of tensors from all processes
    """
    if not dist.is_initialized():
        return [tensor]

    world_size = get_world_size()
    tensors = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensors, tensor)
    return tensors


def broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """
    Broadcast tensor from source to all processes

    Args:
        tensor: Tensor to broadcast
        src: Source rank

    Returns:
        Broadcasted tensor
    """
    if not dist.is_initialized():
        return tensor
    dist.broadcast(tensor, src=src)
    return tensor


def reduce_dict(input_dict: dict, average: bool = True) -> dict:
    """
    Reduce a dictionary of scalar tensors across all processes

    Args:
        input_dict: Dictionary of scalar tensors to reduce
        average: Whether to average the values

    Returns:
        Reduced dictionary
    """
    if not dist.is_initialized():
        return input_dict

    world_size = get_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        names = []
        values = []
        # Sort keys so every process stacks values in the same order
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v.item() for k, v in zip(names, values)}
    return reduced_dict


def setup_model_for_distributed(
    model: torch.nn.Module,
    device_ids: Optional[List[int]] = None,
    output_device: Optional[int] = None,
    find_unused_parameters: bool = True,
    broadcast_buffers: bool = True
) -> torch.nn.Module:
    """
    Wrap model for distributed training

    Args:
        model: Model to wrap
        device_ids: GPU indices to use
        output_device: Device where outputs are gathered
        find_unused_parameters: Whether to find unused parameters
        broadcast_buffers: Whether to broadcast buffers

    Returns:
        Distributed model
    """
    if not dist.is_initialized():
        return model

    if device_ids is None:
        device_ids = [get_local_rank()]
    if output_device is None:
        output_device = device_ids[0]

    model = DDP(
        model,
        device_ids=device_ids,
        output_device=output_device,
        find_unused_parameters=find_unused_parameters,
        broadcast_buffers=broadcast_buffers
    )

    logger.info(f"Wrapped model for distributed training on devices {device_ids}")
    return model


def create_distributed_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = False,
    seed: int = 42
) -> DataLoader:
    """
    Create a distributed dataloader

    Args:
        dataset: Dataset to load
        batch_size: Batch size per GPU
        shuffle: Whether to shuffle
        num_workers: Number of data loading workers
        pin_memory: Whether to pin memory
        drop_last: Whether to drop last incomplete batch
        seed: Random seed for shuffling

    Returns:
        Distributed dataloader
    """
    if dist.is_initialized():
        sampler = DistributedSampler(
            dataset,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last
        )
        shuffle = False  # Sampler handles shuffling
    else:
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last
    )

    return dataloader
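# Usage sketch (illustrative only): when the loader above wraps a DistributedSampler,
# call set_epoch() at the start of each epoch so shuffling differs between epochs and
# stays consistent across ranks. `dataset` and `num_epochs` are hypothetical names.
#
#     loader = create_distributed_dataloader(dataset, batch_size=32, num_workers=4)
#     for epoch in range(num_epochs):
#         if isinstance(loader.sampler, DistributedSampler):
#             loader.sampler.set_epoch(epoch)
#         for batch in loader:
#             ...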
def set_random_seed(seed: int, deterministic: bool = False):
    """
    Set random seed for reproducibility

    Args:
        seed: Random seed
        deterministic: Whether to use deterministic algorithms
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic and torch.cuda.is_available():
        # Configure cuDNN determinism settings
        cudnn.deterministic = True
        cudnn.benchmark = False

    logger.info(f"Set random seed to {seed} (deterministic={deterministic})")


def print_distributed_info():
    """Print distributed training information."""
    if not dist.is_initialized():
        logger.info("Distributed training not initialized")
        return

    logger.info("Distributed Info:")
    logger.info(f"  Backend: {dist.get_backend()}")
    logger.info(f"  World size: {get_world_size()}")
    logger.info(f"  Rank: {get_rank()}")
    logger.info(f"  Local rank: {get_local_rank()}")
    logger.info(f"  Master addr: {os.environ.get('MASTER_ADDR', 'N/A')}")
    logger.info(f"  Master port: {os.environ.get('MASTER_PORT', 'N/A')}")


class DistributedMetricAggregator:
    """Helper class to aggregate metrics across distributed processes."""

    def __init__(self):
        self.metrics = {}

    def update(self, **kwargs):
        """Update metrics."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.detach()
            else:
                v = torch.tensor(v, dtype=torch.float32)
            if torch.cuda.is_available():
                v = v.cuda()
            self.metrics[k] = v

    def aggregate(self, average: bool = True) -> dict:
        """Aggregate metrics across all processes."""
        if not dist.is_initialized():
            return {k: v.item() if isinstance(v, torch.Tensor) else v
                    for k, v in self.metrics.items()}

        aggregated = {}
        world_size = get_world_size()

        for k, v in self.metrics.items():
            # All-reduce (sum) across processes
            all_reduce(v)
            # Average if requested
            if average:
                v = v / world_size
            aggregated[k] = v.item()

        return aggregated

    def reset(self):
        """Reset metrics."""
        self.metrics = {}


def save_on_master(*args, **kwargs):
    """Decorator factory: the returned decorator runs the wrapped save function only on the master process."""
    def decorator(func):
        def wrapper(*args, **kwargs):
            if is_main_process():
                return func(*args, **kwargs)
            # Return None on non-master processes
            return None
        return wrapper
    return decorator


def log_on_master(logger: logging.Logger):
    """Get a logger that only logs on master process."""
    class MasterOnlyLogger:
        def __init__(self, logger):
            self.logger = logger

        def __getattr__(self, name):
            if is_main_process():
                return getattr(self.logger, name)
            return lambda *args, **kwargs: None

    return MasterOnlyLogger(logger)
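# End-to-end sketch (illustrative only, not part of the public API): how the helpers in
# this module are typically combined in a training script. `MyModel`, `my_dataset`, and
# `run_training` are hypothetical placeholders.
#
#     def run_training():
#         set_random_seed(42, deterministic=False)
#         setup_distributed()
#         try:
#             device = f"cuda:{get_local_rank()}" if torch.cuda.is_available() else "cpu"
#             model = setup_model_for_distributed(MyModel().to(device))
#             loader = create_distributed_dataloader(my_dataset, batch_size=32)
#             log = log_on_master(logger)
#             log.info("training started")
#             # ... training loop, using reduce_dict() or DistributedMetricAggregator
#             #     to report metrics averaged across ranks ...
#         finally:
#             cleanup_distributed()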