init commit
This commit is contained in:
271
utils/ascend_support.py
Normal file
271
utils/ascend_support.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Huawei Ascend NPU support utilities for BGE fine-tuning
|
||||
|
||||
This module provides compatibility layer for running BGE models on Ascend NPUs.
|
||||
Requires torch_npu to be installed: pip install torch_npu
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global flag to track Ascend availability
|
||||
_ASCEND_AVAILABLE = False
|
||||
_ASCEND_CHECK_DONE = False
|
||||
|
||||
|
||||
def check_ascend_availability() -> bool:
|
||||
"""
|
||||
Check if Ascend NPU is available
|
||||
|
||||
Returns:
|
||||
bool: True if Ascend NPU is available and properly configured
|
||||
"""
|
||||
global _ASCEND_AVAILABLE, _ASCEND_CHECK_DONE
|
||||
|
||||
if _ASCEND_CHECK_DONE:
|
||||
return _ASCEND_AVAILABLE
|
||||
|
||||
_ASCEND_CHECK_DONE = True
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
|
||||
# Check if NPU is available
|
||||
if hasattr(torch, 'npu') and torch.npu.is_available():
|
||||
_ASCEND_AVAILABLE = True
|
||||
npu_count = torch.npu.device_count()
|
||||
logger.info(f"Ascend NPU available with {npu_count} devices")
|
||||
|
||||
# Log NPU information
|
||||
for i in range(npu_count):
|
||||
props = torch.npu.get_device_properties(i)
|
||||
logger.info(f"NPU {i}: {props.name}, Total Memory: {props.total_memory / 1024**3:.2f} GB")
|
||||
else:
|
||||
logger.info("Ascend NPU not available")
|
||||
|
||||
except ImportError:
|
||||
logger.debug("torch_npu not installed. Ascend NPU support disabled.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Ascend availability: {e}")
|
||||
|
||||
return _ASCEND_AVAILABLE
|
||||
|
||||
|
||||
def get_ascend_backend() -> str:
|
||||
"""
|
||||
Get the appropriate backend for Ascend NPU
|
||||
|
||||
Returns:
|
||||
str: 'hccl' for Ascend, 'nccl' for CUDA, 'gloo' for CPU
|
||||
"""
|
||||
if check_ascend_availability():
|
||||
return 'hccl'
|
||||
elif torch.cuda.is_available():
|
||||
return 'nccl'
|
||||
else:
|
||||
return 'gloo'
|
||||
|
||||
|
||||
def setup_ascend_device(local_rank: int = 0) -> torch.device:
|
||||
"""
|
||||
Setup Ascend NPU device
|
||||
|
||||
Args:
|
||||
local_rank: Local rank for distributed training
|
||||
|
||||
Returns:
|
||||
torch.device: Configured device
|
||||
"""
|
||||
if not check_ascend_availability():
|
||||
raise RuntimeError("Ascend NPU not available")
|
||||
|
||||
try:
|
||||
# Set NPU device
|
||||
torch.npu.set_device(local_rank)
|
||||
device = torch.device(f'npu:{local_rank}')
|
||||
logger.info(f"Using Ascend NPU device: {device}")
|
||||
return device
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup Ascend device: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def convert_model_for_ascend(model: torch.nn.Module) -> torch.nn.Module:
|
||||
"""
|
||||
Convert model for Ascend NPU compatibility
|
||||
|
||||
Args:
|
||||
model: PyTorch model
|
||||
|
||||
Returns:
|
||||
Model configured for Ascend
|
||||
"""
|
||||
if not check_ascend_availability():
|
||||
return model
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
|
||||
# Apply Ascend-specific optimizations
|
||||
model = torch_npu.contrib.transfer_to_npu(model)
|
||||
|
||||
# Enable Ascend graph mode if available
|
||||
if hasattr(torch_npu, 'enable_graph_mode'):
|
||||
torch_npu.enable_graph_mode()
|
||||
logger.info("Enabled Ascend graph mode")
|
||||
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply Ascend optimizations: {e}")
|
||||
return model
|
||||
|
||||
|
||||
def optimize_for_ascend(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Optimize configuration for Ascend NPU
|
||||
|
||||
Args:
|
||||
config: Training configuration
|
||||
|
||||
Returns:
|
||||
Optimized configuration
|
||||
"""
|
||||
if not check_ascend_availability():
|
||||
return config
|
||||
|
||||
# Create a copy to avoid modifying original
|
||||
config = config.copy()
|
||||
|
||||
# Ascend-specific optimizations
|
||||
ascend_config = config.get('ascend', {})
|
||||
|
||||
# Set backend to hccl for distributed training
|
||||
if ascend_config.get('backend') == 'hccl':
|
||||
config['distributed'] = config.get('distributed', {})
|
||||
config['distributed']['backend'] = 'hccl'
|
||||
|
||||
# Enable mixed precision if specified
|
||||
if ascend_config.get('mixed_precision', False):
|
||||
config['optimization'] = config.get('optimization', {})
|
||||
config['optimization']['fp16'] = True
|
||||
logger.info("Enabled mixed precision for Ascend")
|
||||
|
||||
# Set visible devices
|
||||
if 'visible_devices' in ascend_config:
|
||||
os.environ['ASCEND_RT_VISIBLE_DEVICES'] = ascend_config['visible_devices']
|
||||
logger.info(f"Set ASCEND_RT_VISIBLE_DEVICES={ascend_config['visible_devices']}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_ascend_memory_info(device_id: int = 0) -> Tuple[float, float]:
|
||||
"""
|
||||
Get Ascend NPU memory information
|
||||
|
||||
Args:
|
||||
device_id: NPU device ID
|
||||
|
||||
Returns:
|
||||
Tuple of (used_memory_gb, total_memory_gb)
|
||||
"""
|
||||
if not check_ascend_availability():
|
||||
return 0.0, 0.0
|
||||
|
||||
try:
|
||||
import torch_npu
|
||||
|
||||
# Get memory info
|
||||
used = torch_npu.memory_allocated(device_id) / 1024**3
|
||||
total = torch_npu.get_device_properties(device_id).total_memory / 1024**3
|
||||
|
||||
return used, total
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get Ascend memory info: {e}")
|
||||
return 0.0, 0.0
|
||||
|
||||
|
||||
def clear_ascend_cache():
|
||||
"""Clear Ascend NPU memory cache"""
|
||||
if not check_ascend_availability():
|
||||
return
|
||||
|
||||
try:
|
||||
torch.npu.empty_cache()
|
||||
logger.debug("Cleared Ascend NPU cache")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear Ascend cache: {e}")
|
||||
|
||||
|
||||
def is_ascend_available() -> bool:
|
||||
"""
|
||||
Simple check for Ascend availability
|
||||
|
||||
Returns:
|
||||
bool: True if Ascend NPU is available
|
||||
"""
|
||||
return check_ascend_availability()
|
||||
|
||||
|
||||
class AscendDataLoader:
|
||||
"""
|
||||
Wrapper for DataLoader with Ascend-specific optimizations
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader, device):
|
||||
self.dataloader = dataloader
|
||||
self.device = device
|
||||
|
||||
def __iter__(self):
|
||||
for batch in self.dataloader:
|
||||
# Move batch to NPU device
|
||||
if isinstance(batch, dict):
|
||||
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in batch.items()}
|
||||
elif isinstance(batch, (list, tuple)):
|
||||
batch = [v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for v in batch]
|
||||
else:
|
||||
batch = batch.to(self.device)
|
||||
yield batch
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataloader)
|
||||
|
||||
|
||||
# Compatibility functions
|
||||
def npu_synchronize():
|
||||
"""Synchronize NPU device (compatibility wrapper)"""
|
||||
if check_ascend_availability():
|
||||
try:
|
||||
torch.npu.synchronize()
|
||||
except Exception as e:
|
||||
logger.warning(f"NPU synchronize failed: {e}")
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def get_device_name(device: torch.device) -> str:
|
||||
"""
|
||||
Get device name with Ascend support
|
||||
|
||||
Args:
|
||||
device: PyTorch device
|
||||
|
||||
Returns:
|
||||
str: Device name
|
||||
"""
|
||||
if device.type == 'npu':
|
||||
if check_ascend_availability():
|
||||
try:
|
||||
props = torch.npu.get_device_properties(device.index or 0)
|
||||
return f"Ascend {props.name}"
|
||||
except:
|
||||
return "Ascend NPU"
|
||||
return "NPU (unavailable)"
|
||||
elif device.type == 'cuda':
|
||||
return torch.cuda.get_device_name(device.index or 0)
|
||||
else:
|
||||
return str(device)
|
||||
Reference in New Issue
Block a user