"""
|
|
Huawei Ascend NPU support utilities for BGE fine-tuning
|
|
|
|
This module provides compatibility layer for running BGE models on Ascend NPUs.
|
|
Requires torch_npu to be installed: pip install torch_npu
|
|
"""
|
|
import os
import logging
from typing import Optional, Dict, Any, Tuple

import torch

logger = logging.getLogger(__name__)

# Module-level flags that cache the result of the Ascend availability check
_ASCEND_AVAILABLE = False
_ASCEND_CHECK_DONE = False


def check_ascend_availability() -> bool:
    """
    Check if an Ascend NPU is available.

    Returns:
        bool: True if an Ascend NPU is available and properly configured
    """
    global _ASCEND_AVAILABLE, _ASCEND_CHECK_DONE

    if _ASCEND_CHECK_DONE:
        return _ASCEND_AVAILABLE

    _ASCEND_CHECK_DONE = True

    try:
        import torch_npu  # noqa: F401 - importing registers the 'npu' device with torch

        # Check if an NPU device is visible
        if hasattr(torch, 'npu') and torch.npu.is_available():
            _ASCEND_AVAILABLE = True
            npu_count = torch.npu.device_count()
            logger.info(f"Ascend NPU available with {npu_count} devices")

            # Log NPU information
            for i in range(npu_count):
                props = torch.npu.get_device_properties(i)
                logger.info(f"NPU {i}: {props.name}, Total Memory: {props.total_memory / 1024**3:.2f} GB")
        else:
            logger.info("Ascend NPU not available")

    except ImportError:
        logger.debug("torch_npu not installed. Ascend NPU support disabled.")
    except Exception as e:
        logger.warning(f"Error checking Ascend availability: {e}")

    return _ASCEND_AVAILABLE


def get_ascend_backend() -> str:
    """
    Get the appropriate distributed backend for the current hardware.

    Returns:
        str: 'hccl' for Ascend NPU, 'nccl' for CUDA, 'gloo' for CPU
    """
    if check_ascend_availability():
        return 'hccl'
    elif torch.cuda.is_available():
        return 'nccl'
    else:
        return 'gloo'
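
# Example (illustrative sketch, not part of this module's API): the backend string
# returned above is meant to be passed straight to torch.distributed. Rendezvous
# environment variables (RANK, WORLD_SIZE, MASTER_ADDR, ...) are assumed to be set
# by the launcher.
#
#   import torch.distributed as dist
#   dist.init_process_group(backend=get_ascend_backend())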


def setup_ascend_device(local_rank: int = 0) -> torch.device:
    """
    Set up an Ascend NPU device.

    Args:
        local_rank: Local rank for distributed training

    Returns:
        torch.device: Configured device
    """
    if not check_ascend_availability():
        raise RuntimeError("Ascend NPU not available")

    try:
        # Bind this process to the NPU matching its local rank
        torch.npu.set_device(local_rank)
        device = torch.device(f'npu:{local_rank}')
        logger.info(f"Using Ascend NPU device: {device}")
        return device
    except Exception as e:
        logger.error(f"Failed to setup Ascend device: {e}")
        raise
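
# Example (illustrative, assumes a torchrun-style launcher that exports LOCAL_RANK):
#
#   local_rank = int(os.environ.get("LOCAL_RANK", 0))
#   device = setup_ascend_device(local_rank)
#   model = model.to(device)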


def convert_model_for_ascend(model: torch.nn.Module) -> torch.nn.Module:
    """
    Convert model for Ascend NPU compatibility.

    Args:
        model: PyTorch model

    Returns:
        Model configured for Ascend
    """
    if not check_ascend_availability():
        return model

    try:
        import torch_npu

        # Apply Ascend-specific optimizations
        model = torch_npu.contrib.transfer_to_npu(model)

        # Enable Ascend graph mode if available
        if hasattr(torch_npu, 'enable_graph_mode'):
            torch_npu.enable_graph_mode()
            logger.info("Enabled Ascend graph mode")

        return model
    except Exception as e:
        logger.warning(f"Failed to apply Ascend optimizations: {e}")
        return model
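
# Example (illustrative call order during setup): the optimizations above are
# best-effort and wrapped in try/except, so the call is safe even when torch_npu
# does not expose these hooks.
#
#   device = setup_ascend_device(local_rank)
#   model = convert_model_for_ascend(model.to(device))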


def optimize_for_ascend(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Optimize a training configuration for Ascend NPU.

    Args:
        config: Training configuration

    Returns:
        Optimized configuration
    """
    if not check_ascend_availability():
        return config

    # Work on a copy so the caller's config is not modified; the nested sections
    # mutated below are copied as well, since dict.copy() is only shallow.
    config = config.copy()

    # Ascend-specific options live under the 'ascend' key
    ascend_config = config.get('ascend', {})

    # Use the hccl backend for distributed training
    if ascend_config.get('backend') == 'hccl':
        config['distributed'] = dict(config.get('distributed', {}))
        config['distributed']['backend'] = 'hccl'

    # Enable mixed precision if specified
    if ascend_config.get('mixed_precision', False):
        config['optimization'] = dict(config.get('optimization', {}))
        config['optimization']['fp16'] = True
        logger.info("Enabled mixed precision for Ascend")

    # Restrict the NPUs visible to this process
    if 'visible_devices' in ascend_config:
        os.environ['ASCEND_RT_VISIBLE_DEVICES'] = ascend_config['visible_devices']
        logger.info(f"Set ASCEND_RT_VISIBLE_DEVICES={ascend_config['visible_devices']}")

    return config
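
# Example (illustrative config shape; the 'ascend' key names are the ones read
# above, everything else about the surrounding training config is hypothetical):
#
#   cfg = {
#       "ascend": {
#           "backend": "hccl",
#           "mixed_precision": True,
#           "visible_devices": "0,1,2,3",
#       },
#   }
#   cfg = optimize_for_ascend(cfg)
#   # -> cfg["distributed"]["backend"] == "hccl" and cfg["optimization"]["fp16"] is True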


def get_ascend_memory_info(device_id: int = 0) -> Tuple[float, float]:
    """
    Get Ascend NPU memory information.

    Args:
        device_id: NPU device ID

    Returns:
        Tuple of (used_memory_gb, total_memory_gb)
    """
    if not check_ascend_availability():
        return 0.0, 0.0

    try:
        # Query memory through the torch.npu namespace registered by torch_npu,
        # matching how device properties are read elsewhere in this module
        used = torch.npu.memory_allocated(device_id) / 1024**3
        total = torch.npu.get_device_properties(device_id).total_memory / 1024**3

        return used, total
    except Exception as e:
        logger.warning(f"Failed to get Ascend memory info: {e}")
        return 0.0, 0.0
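
# Example (illustrative): simple memory logging around a training step.
#
#   used_gb, total_gb = get_ascend_memory_info(0)
#   logger.info(f"NPU memory: {used_gb:.2f}/{total_gb:.2f} GB")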


def clear_ascend_cache():
    """Clear Ascend NPU memory cache"""
    if not check_ascend_availability():
        return

    try:
        torch.npu.empty_cache()
        logger.debug("Cleared Ascend NPU cache")
    except Exception as e:
        logger.warning(f"Failed to clear Ascend cache: {e}")


def is_ascend_available() -> bool:
    """
    Simple check for Ascend availability.

    Returns:
        bool: True if Ascend NPU is available
    """
    return check_ascend_availability()


class AscendDataLoader:
    """
    Wrapper around a DataLoader that moves each batch to the Ascend NPU device.
    """

    def __init__(self, dataloader, device):
        self.dataloader = dataloader
        self.device = device

    def __iter__(self):
        for batch in self.dataloader:
            # Move batch tensors to the NPU device
            if isinstance(batch, dict):
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}
            elif isinstance(batch, (list, tuple)):
                batch = [v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for v in batch]
            else:
                batch = batch.to(self.device)
            yield batch

    def __len__(self):
        return len(self.dataloader)
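
# Example (illustrative): wrapping an existing DataLoader so batches land on the NPU.
# The name train_loader is hypothetical.
#
#   device = setup_ascend_device(local_rank)
#   train_loader = AscendDataLoader(train_loader, device)
#   for batch in train_loader:
#       ...  # batch tensors are already on npu:<local_rank>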


# Compatibility functions
def npu_synchronize():
    """Synchronize NPU device (compatibility wrapper)"""
    if check_ascend_availability():
        try:
            torch.npu.synchronize()
        except Exception as e:
            logger.warning(f"NPU synchronize failed: {e}")
    elif torch.cuda.is_available():
        torch.cuda.synchronize()


def get_device_name(device: torch.device) -> str:
    """
    Get a human-readable device name, with Ascend support.

    Args:
        device: PyTorch device

    Returns:
        str: Device name
    """
    if device.type == 'npu':
        if check_ascend_availability():
            try:
                props = torch.npu.get_device_properties(device.index or 0)
                return f"Ascend {props.name}"
            except Exception:
                return "Ascend NPU"
        return "NPU (unavailable)"
    elif device.type == 'cuda':
        return torch.cuda.get_device_name(device.index or 0)
    else:
        return str(device)
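

# Minimal smoke test (illustrative, not part of the training pipeline): prints what
# this module detects on the current machine. Safe to run on hosts without an NPU,
# since every helper above degrades gracefully.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    print(f"Ascend available: {is_ascend_available()}")
    print(f"Distributed backend: {get_ascend_backend()}")
    if is_ascend_available():
        used_gb, total_gb = get_ascend_memory_info(0)
        print(f"NPU 0 memory: {used_gb:.2f}/{total_gb:.2f} GB")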