bge_finetune/utils/ascend_support.py
2025-07-22 16:55:25 +08:00

271 lines
7.4 KiB
Python

"""
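Huawei Ascend NPU support utilities for BGE fine-tuning.
This module provides a compatibility layer for running BGE models on Ascend NPUs.
NPU features require torch_npu (pip install torch_npu); when it is not installed,
the availability check reports False and the helpers fall back or do nothing.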
"""
import os
import logging
from typing import Optional, Dict, Any, Tuple
import torch
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Cached outcome of the Ascend probe. check_ascend_availability() sets
# _ASCEND_CHECK_DONE on its first call and stores the result in
# _ASCEND_AVAILABLE, so later calls return the cached value without probing again.
_ASCEND_AVAILABLE = False
_ASCEND_CHECK_DONE = False
def check_ascend_availability() -> bool:
    """
    Probe once for a usable Ascend NPU and cache the answer.

    The first call imports torch_npu and queries torch.npu; every later call
    returns the cached module-level flag without probing again. A missing
    torch_npu package or any error raised while probing is logged and leaves
    the cached flag as it was at that point.

    Returns:
        bool: True if torch_npu imported and torch.npu reported an available device
    """
    global _ASCEND_AVAILABLE, _ASCEND_CHECK_DONE

    if not _ASCEND_CHECK_DONE:
        # Mark the probe as done up front so a failing probe is not retried.
        _ASCEND_CHECK_DONE = True
        try:
            # Imported for its side effect of registering the torch.npu namespace.
            import torch_npu  # noqa: F401

            npu_ready = hasattr(torch, 'npu') and torch.npu.is_available()
            if not npu_ready:
                logger.info("Ascend NPU not available")
            else:
                _ASCEND_AVAILABLE = True
                npu_count = torch.npu.device_count()
                logger.info(f"Ascend NPU available with {npu_count} devices")
                # Report the name and memory size of every visible NPU
                for i in range(npu_count):
                    props = torch.npu.get_device_properties(i)
                    logger.info(f"NPU {i}: {props.name}, Total Memory: {props.total_memory / 1024**3:.2f} GB")
        except ImportError:
            logger.debug("torch_npu not installed. Ascend NPU support disabled.")
        except Exception as e:
            logger.warning(f"Error checking Ascend availability: {e}")

    return _ASCEND_AVAILABLE
def get_ascend_backend() -> str:
    """
    Pick the distributed backend name that matches the available hardware.

    Ascend takes priority over CUDA; 'gloo' is the fallback when neither
    accelerator is detected.

    Returns:
        str: 'hccl' for Ascend, 'nccl' for CUDA, 'gloo' otherwise
    """
    if check_ascend_availability():
        return 'hccl'
    return 'nccl' if torch.cuda.is_available() else 'gloo'
def setup_ascend_device(local_rank: int = 0) -> torch.device:
    """
    Select the NPU for this process and return the matching torch.device.

    Args:
        local_rank: Index of the NPU to bind to (the local rank in distributed runs)

    Returns:
        torch.device: The 'npu:<local_rank>' device

    Raises:
        RuntimeError: If no Ascend NPU is available
        Exception: Any error raised while selecting the device is logged and re-raised
    """
    if not check_ascend_availability():
        raise RuntimeError("Ascend NPU not available")

    try:
        # Bind the current process to the requested NPU
        torch.npu.set_device(local_rank)
        device = torch.device(f'npu:{local_rank}')
        logger.info(f"Using Ascend NPU device: {device}")
    except Exception as e:
        logger.error(f"Failed to setup Ascend device: {e}")
        raise
    else:
        return device
def convert_model_for_ascend(model: torch.nn.Module) -> torch.nn.Module:
    """
    Apply Ascend-specific setup for a model (best effort).

    When no Ascend NPU is available the model is returned untouched. Any
    failure while applying the Ascend setup is logged as a warning and the
    model is still returned, so this function never raises for that reason.

    NOTE(review): this function does not move the model to an NPU device;
    callers presumably still need model.to(device) - confirm against callers.

    Args:
        model: PyTorch model

    Returns:
        The model (the same object unless a callable transfer hook replaced it)
    """
    if not check_ascend_availability():
        return model
    try:
        import torch_npu
        # torch_npu ships ``contrib.transfer_to_npu`` as a module whose import
        # applies the CUDA->NPU compatibility patches. Calling it like a
        # function raises TypeError, which previously aborted this whole block.
        from torch_npu.contrib import transfer_to_npu

        # Only call it if a given torch_npu build really exposes a callable.
        if callable(transfer_to_npu):
            model = transfer_to_npu(model)

        # Enable Ascend graph mode only on builds that expose the switch
        if hasattr(torch_npu, 'enable_graph_mode'):
            torch_npu.enable_graph_mode()
            logger.info("Enabled Ascend graph mode")
    except Exception as e:
        # Best effort: keep training usable even if the setup step fails.
        logger.warning(f"Failed to apply Ascend optimizations: {e}")
    return model
def optimize_for_ascend(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Derive an Ascend-adjusted training configuration.

    Reads the optional ``config['ascend']`` section and, based on it:
      * ``backend == 'hccl'``     -> sets ``config['distributed']['backend'] = 'hccl'``
      * ``mixed_precision`` truthy -> sets ``config['optimization']['fp16'] = True``
      * ``visible_devices`` given  -> exports ``ASCEND_RT_VISIBLE_DEVICES``

    The input mapping and its ``distributed`` / ``optimization`` sub-dicts are
    never mutated; the sections that get modified are copied first. Setting the
    environment variable is a process-wide side effect.

    Args:
        config: Training configuration

    Returns:
        The adjusted copy, or the original ``config`` object unchanged when no
        Ascend NPU is available
    """
    if not check_ascend_availability():
        return config

    # Shallow copy of the top level; nested sections are copied below before
    # being written to, so the caller's dictionaries stay untouched.
    config = config.copy()

    # ``or {}`` also covers an explicit ``ascend: null`` entry.
    ascend_config = config.get('ascend') or {}

    # Set backend to hccl for distributed training
    if ascend_config.get('backend') == 'hccl':
        config['distributed'] = {**(config.get('distributed') or {}), 'backend': 'hccl'}

    # Enable mixed precision if specified
    if ascend_config.get('mixed_precision', False):
        config['optimization'] = {**(config.get('optimization') or {}), 'fp16': True}
        logger.info("Enabled mixed precision for Ascend")

    # Set visible devices. os.environ only accepts strings, while config
    # files commonly yield an int (``0``) or a list (``[0, 1]``) here.
    visible_devices = ascend_config.get('visible_devices')
    if visible_devices is not None:
        if isinstance(visible_devices, (list, tuple)):
            visible_devices = ','.join(str(dev) for dev in visible_devices)
        else:
            visible_devices = str(visible_devices)
        os.environ['ASCEND_RT_VISIBLE_DEVICES'] = visible_devices
        logger.info(f"Set ASCEND_RT_VISIBLE_DEVICES={visible_devices}")

    return config
def get_ascend_memory_info(device_id: int = 0) -> Tuple[float, float]:
    """
    Get Ascend NPU memory information (best effort).

    NOTE(review): the "used" figure comes from ``memory_allocated`` and so
    presumably reflects memory held by tensors of this process rather than
    device-wide usage - confirm against the torch_npu documentation.

    Args:
        device_id: NPU device ID

    Returns:
        Tuple of (used_memory_gb, total_memory_gb); (0.0, 0.0) when no Ascend
        NPU is available or the query fails (the failure is logged)
    """
    if not check_ascend_availability():
        return 0.0, 0.0
    try:
        # Importing torch_npu makes sure the torch.npu namespace is registered.
        import torch_npu  # noqa: F401

        # Query through torch.npu, the same namespace the rest of this module
        # uses; the memory helpers are not attributes of the top-level
        # torch_npu package, so the old calls failed and were swallowed.
        used = torch.npu.memory_allocated(device_id) / 1024**3
        total = torch.npu.get_device_properties(device_id).total_memory / 1024**3
    except Exception as e:
        logger.warning(f"Failed to get Ascend memory info: {e}")
        return 0.0, 0.0
    return used, total
def clear_ascend_cache():
    """
    Release cached NPU memory via torch.npu.empty_cache().

    Does nothing when no Ascend NPU is available; a failure while clearing
    is logged as a warning instead of being raised.
    """
    if check_ascend_availability():
        try:
            torch.npu.empty_cache()
            logger.debug("Cleared Ascend NPU cache")
        except Exception as e:
            logger.warning(f"Failed to clear Ascend cache: {e}")
def is_ascend_available() -> bool:
    """
    Report whether an Ascend NPU can be used.

    Thin alias that delegates to check_ascend_availability().

    Returns:
        bool: True if Ascend NPU is available
    """
    available = check_ascend_availability()
    return available
class AscendDataLoader:
    """
    Wrapper around an iterable of batches that moves each batch to a device.

    Iterating yields the batches of the wrapped loader with every tensor
    transferred to ``device``. Tensors are found recursively inside dicts,
    lists and tuples, so nested batches such as
    ``{'query': {'input_ids': ...}, 'passage': {...}}`` are moved as well.

    Conversion rules:
      * ``torch.Tensor``        -> ``tensor.to(device)``
      * ``dict``                -> new plain dict with converted values
      * ``list`` / ``tuple``    -> new list with converted items
      * object with ``.to()``   -> ``obj.to(device)``
      * anything else           -> returned unchanged (e.g. str, int, None)

    Args:
        dataloader: Iterable of batches; must support ``len()`` for ``__len__``
        device: Target passed to ``.to()`` (e.g. a ``torch.device``)
    """

    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def _to_device(self, item):
        """Return ``item`` with all contained tensors moved to ``self.device``."""
        if isinstance(item, torch.Tensor):
            return item.to(self.device)
        if isinstance(item, dict):
            return {key: self._to_device(value) for key, value in item.items()}
        if isinstance(item, (list, tuple)):
            # Sequences come back as lists, as the top level always has.
            return [self._to_device(value) for value in item]
        # Tensor-like containers that expose their own ``.to`` method.
        mover = getattr(item, 'to', None)
        if callable(mover):
            return mover(self.device)
        # Plain Python values (ids, strings, ...) have nothing to move.
        return item

    def __iter__(self):
        for batch in self.dataloader:
            yield self._to_device(batch)

    def __len__(self):
        return len(self.dataloader)
# Compatibility functions
def npu_synchronize():
    """
    Synchronize the active accelerator (compatibility wrapper).

    Uses torch.npu.synchronize() when an Ascend NPU is available (a failure
    there is logged as a warning), otherwise torch.cuda.synchronize() when
    CUDA is available; on CPU-only setups it does nothing.
    """
    if not check_ascend_availability():
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return

    try:
        torch.npu.synchronize()
    except Exception as e:
        logger.warning(f"NPU synchronize failed: {e}")
def get_device_name(device: torch.device) -> str:
    """
    Get a human-readable device name with Ascend support.

    Args:
        device: PyTorch device

    Returns:
        str: "Ascend <name>" for an available NPU, "Ascend NPU" if its
        properties cannot be read, "NPU (unavailable)" if the device type is
        'npu' but no Ascend NPU is available, the CUDA device name for
        'cuda' devices, and ``str(device)`` for everything else
    """
    if device.type == 'npu':
        if not check_ascend_availability():
            return "NPU (unavailable)"
        try:
            # A device without an explicit index is treated as index 0.
            props = torch.npu.get_device_properties(device.index or 0)
            return f"Ascend {props.name}"
        except Exception:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt / SystemExit are not swallowed here.
            return "Ascend NPU"
    if device.type == 'cuda':
        return torch.cuda.get_device_name(device.index or 0)
    return str(device)