""" Huawei Ascend NPU support utilities for BGE fine-tuning This module provides compatibility layer for running BGE models on Ascend NPUs. Requires torch_npu to be installed: pip install torch_npu """ import os import logging from typing import Optional, Dict, Any, Tuple import torch logger = logging.getLogger(__name__) # Global flag to track Ascend availability _ASCEND_AVAILABLE = False _ASCEND_CHECK_DONE = False def check_ascend_availability() -> bool: """ Check if Ascend NPU is available Returns: bool: True if Ascend NPU is available and properly configured """ global _ASCEND_AVAILABLE, _ASCEND_CHECK_DONE if _ASCEND_CHECK_DONE: return _ASCEND_AVAILABLE _ASCEND_CHECK_DONE = True try: import torch_npu # Check if NPU is available if hasattr(torch, 'npu') and torch.npu.is_available(): _ASCEND_AVAILABLE = True npu_count = torch.npu.device_count() logger.info(f"Ascend NPU available with {npu_count} devices") # Log NPU information for i in range(npu_count): props = torch.npu.get_device_properties(i) logger.info(f"NPU {i}: {props.name}, Total Memory: {props.total_memory / 1024**3:.2f} GB") else: logger.info("Ascend NPU not available") except ImportError: logger.debug("torch_npu not installed. Ascend NPU support disabled.") except Exception as e: logger.warning(f"Error checking Ascend availability: {e}") return _ASCEND_AVAILABLE def get_ascend_backend() -> str: """ Get the appropriate backend for Ascend NPU Returns: str: 'hccl' for Ascend, 'nccl' for CUDA, 'gloo' for CPU """ if check_ascend_availability(): return 'hccl' elif torch.cuda.is_available(): return 'nccl' else: return 'gloo' def setup_ascend_device(local_rank: int = 0) -> torch.device: """ Setup Ascend NPU device Args: local_rank: Local rank for distributed training Returns: torch.device: Configured device """ if not check_ascend_availability(): raise RuntimeError("Ascend NPU not available") try: # Set NPU device torch.npu.set_device(local_rank) device = torch.device(f'npu:{local_rank}') logger.info(f"Using Ascend NPU device: {device}") return device except Exception as e: logger.error(f"Failed to setup Ascend device: {e}") raise def convert_model_for_ascend(model: torch.nn.Module) -> torch.nn.Module: """ Convert model for Ascend NPU compatibility Args: model: PyTorch model Returns: Model configured for Ascend """ if not check_ascend_availability(): return model try: import torch_npu # Apply Ascend-specific optimizations model = torch_npu.contrib.transfer_to_npu(model) # Enable Ascend graph mode if available if hasattr(torch_npu, 'enable_graph_mode'): torch_npu.enable_graph_mode() logger.info("Enabled Ascend graph mode") return model except Exception as e: logger.warning(f"Failed to apply Ascend optimizations: {e}") return model def optimize_for_ascend(config: Dict[str, Any]) -> Dict[str, Any]: """ Optimize configuration for Ascend NPU Args: config: Training configuration Returns: Optimized configuration """ if not check_ascend_availability(): return config # Create a copy to avoid modifying original config = config.copy() # Ascend-specific optimizations ascend_config = config.get('ascend', {}) # Set backend to hccl for distributed training if ascend_config.get('backend') == 'hccl': config['distributed'] = config.get('distributed', {}) config['distributed']['backend'] = 'hccl' # Enable mixed precision if specified if ascend_config.get('mixed_precision', False): config['optimization'] = config.get('optimization', {}) config['optimization']['fp16'] = True logger.info("Enabled mixed precision for Ascend") # Set visible devices if 
'visible_devices' in ascend_config: os.environ['ASCEND_RT_VISIBLE_DEVICES'] = ascend_config['visible_devices'] logger.info(f"Set ASCEND_RT_VISIBLE_DEVICES={ascend_config['visible_devices']}") return config def get_ascend_memory_info(device_id: int = 0) -> Tuple[float, float]: """ Get Ascend NPU memory information Args: device_id: NPU device ID Returns: Tuple of (used_memory_gb, total_memory_gb) """ if not check_ascend_availability(): return 0.0, 0.0 try: import torch_npu # Get memory info used = torch_npu.memory_allocated(device_id) / 1024**3 total = torch_npu.get_device_properties(device_id).total_memory / 1024**3 return used, total except Exception as e: logger.warning(f"Failed to get Ascend memory info: {e}") return 0.0, 0.0 def clear_ascend_cache(): """Clear Ascend NPU memory cache""" if not check_ascend_availability(): return try: torch.npu.empty_cache() logger.debug("Cleared Ascend NPU cache") except Exception as e: logger.warning(f"Failed to clear Ascend cache: {e}") def is_ascend_available() -> bool: """ Simple check for Ascend availability Returns: bool: True if Ascend NPU is available """ return check_ascend_availability() class AscendDataLoader: """ Wrapper for DataLoader with Ascend-specific optimizations """ def __init__(self, dataloader, device): self.dataloader = dataloader self.device = device def __iter__(self): for batch in self.dataloader: # Move batch to NPU device if isinstance(batch, dict): batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} elif isinstance(batch, (list, tuple)): batch = [v.to(self.device) if isinstance(v, torch.Tensor) else v for v in batch] else: batch = batch.to(self.device) yield batch def __len__(self): return len(self.dataloader) # Compatibility functions def npu_synchronize(): """Synchronize NPU device (compatibility wrapper)""" if check_ascend_availability(): try: torch.npu.synchronize() except Exception as e: logger.warning(f"NPU synchronize failed: {e}") elif torch.cuda.is_available(): torch.cuda.synchronize() def get_device_name(device: torch.device) -> str: """ Get device name with Ascend support Args: device: PyTorch device Returns: str: Device name """ if device.type == 'npu': if check_ascend_availability(): try: props = torch.npu.get_device_properties(device.index or 0) return f"Ascend {props.name}" except: return "Ascend NPU" return "NPU (unavailable)" elif device.type == 'cuda': return torch.cuda.get_device_name(device.index or 0) else: return str(device)
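

# Minimal usage sketch (illustrative only; running this module directly is an
# assumption and is not required by the BGE training pipeline). It simply
# exercises the helpers defined above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    if is_ascend_available():
        # Bind this process to NPU 0 and report memory usage
        device = setup_ascend_device(local_rank=0)
        used_gb, total_gb = get_ascend_memory_info(0)
        logger.info(f"NPU memory: {used_gb:.2f} GB used / {total_gb:.2f} GB total")
        clear_ascend_cache()
    else:
        # Fall back to CUDA or CPU when no NPU is present
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info(f"Selected device: {get_device_name(device)}")
    logger.info(f"Recommended distributed backend: {get_ascend_backend()}")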