init commit
This commit is contained in:
560
config/base_config.py
Normal file
560
config/base_config.py
Normal file
@@ -0,0 +1,560 @@
|
||||
import os
|
||||
import torch
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
"""
|
||||
Base configuration for all trainers and models.
|
||||
|
||||
Configuration Precedence Order:
|
||||
1. Command-line arguments (highest priority)
|
||||
2. Model-specific config sections ([m3], [reranker])
|
||||
3. [hardware] section (for hardware-specific overrides)
|
||||
4. [training] section (general training defaults)
|
||||
5. Dataclass field defaults (lowest priority)
|
||||
|
||||
Example:
|
||||
If batch_size is defined in multiple places:
|
||||
- CLI arg --batch_size 32 (wins)
|
||||
- [m3] batch_size = 16
|
||||
- [hardware] per_device_train_batch_size = 8
|
||||
- [training] default_batch_size = 4
|
||||
- Dataclass default = 2
|
||||
|
||||
Note: Model-specific configs ([m3], [reranker]) take precedence over
|
||||
[hardware] for their respective models to allow fine-grained control.
|
||||
"""
|
||||
temperature: float = 1.0 # Softmax temperature for loss functions
|
||||
use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss
|
||||
negatives_cross_device: bool = False # Use negatives across devices (DDP)
|
||||
label_smoothing: float = 0.0 # Label smoothing for classification losses
|
||||
alpha: float = 1.0 # Weight for auxiliary losses
|
||||
gradient_accumulation_steps: int = 1 # Steps to accumulate gradients
|
||||
num_train_epochs: int = 3 # Number of training epochs
|
||||
eval_steps: int = 1 # Evaluate every N epochs
|
||||
save_steps: int = 1 # Save checkpoint every N epochs
|
||||
warmup_steps: int = 0 # Number of warmup steps for scheduler
|
||||
warmup_ratio: float = 0.0 # Warmup ratio for scheduler
|
||||
metric_for_best_model: str = "accuracy" # Metric to select best model
|
||||
checkpoint_dir: str = "./checkpoints" # Directory for saving checkpoints
|
||||
resume_from_checkpoint: str = "" # Path to resume checkpoint
|
||||
max_grad_norm: float = 1.0 # Max gradient norm for clipping
|
||||
use_amp: bool = False # Use automatic mixed precision
|
||||
|
||||
# Model configuration
|
||||
model_name_or_path: str = "BAAI/bge-m3"
|
||||
cache_dir: Optional[str] = None
|
||||
|
||||
# Data configuration
|
||||
train_data_path: str = "./data/train.jsonl"
|
||||
eval_data_path: Optional[str] = "./data/eval.jsonl"
|
||||
max_seq_length: int = 512
|
||||
query_max_length: int = 64
|
||||
passage_max_length: int = 512
|
||||
|
||||
# Training configuration
|
||||
output_dir: str = "./output"
|
||||
learning_rate: float = 1e-5
|
||||
weight_decay: float = 0.01
|
||||
adam_epsilon: float = 1e-8
|
||||
adam_beta1: float = 0.9
|
||||
adam_beta2: float = 0.999
|
||||
|
||||
# Advanced training configuration
|
||||
fp16: bool = True
|
||||
bf16: bool = False
|
||||
gradient_checkpointing: bool = True
|
||||
deepspeed: Optional[str] = None
|
||||
local_rank: int = -1
|
||||
dataloader_num_workers: int = 4
|
||||
dataloader_pin_memory: bool = True
|
||||
|
||||
# Batch sizes
|
||||
per_device_train_batch_size: int = 4
|
||||
per_device_eval_batch_size: int = 8
|
||||
|
||||
# Logging and saving
|
||||
logging_dir: str = "./logs"
|
||||
# For testing only, in production, set to 100
|
||||
logging_steps: int = 100
|
||||
save_total_limit: int = 3
|
||||
load_best_model_at_end: bool = True
|
||||
greater_is_better: bool = False
|
||||
evaluation_strategy: str = "steps"
|
||||
|
||||
# Checkpoint and resume
|
||||
overwrite_output_dir: bool = False
|
||||
|
||||
# Seed
|
||||
seed: int = 42
|
||||
|
||||
# Platform compatibility
|
||||
device: str = "auto"
|
||||
n_gpu: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialization with robust config loading and device setup
|
||||
Notes:
|
||||
- [hardware] section overrides [training] for overlapping parameters.
|
||||
- Model-specific configs ([m3], [reranker]) override [training]/[hardware] for their models.
|
||||
"""
|
||||
# Load configuration from centralized config system
|
||||
config = self._get_config_safely()
|
||||
|
||||
# Apply configuration values with fallbacks
|
||||
self._apply_config_values(config)
|
||||
|
||||
# Setup device detection
|
||||
self._setup_device()
|
||||
|
||||
def _get_config_safely(self):
|
||||
"""Safely get configuration with multiple fallback strategies"""
|
||||
try:
|
||||
# Try absolute import
|
||||
from utils.config_loader import get_config
|
||||
return get_config()
|
||||
except ImportError:
|
||||
try:
|
||||
# Try absolute import (for scripts)
|
||||
from bge_finetune.utils.config_loader import get_config # type: ignore[import]
|
||||
return get_config()
|
||||
except ImportError:
|
||||
try:
|
||||
# Try direct import (when run as script from root)
|
||||
from utils.config_loader import get_config
|
||||
return get_config()
|
||||
except ImportError:
|
||||
logger.error("Could not import config_loader, using fallback configuration. Please check your environment and config loader path.")
|
||||
return self._create_fallback_config()
|
||||
|
||||
def _create_fallback_config(self):
|
||||
"""Create a minimal fallback configuration object"""
|
||||
|
||||
class FallbackConfig:
|
||||
def __init__(self):
|
||||
self._config = {
|
||||
'training': {
|
||||
'default_batch_size': 4,
|
||||
'default_learning_rate': 1e-5,
|
||||
'default_num_epochs': 3,
|
||||
'default_warmup_ratio': 0.1,
|
||||
'default_weight_decay': 0.01,
|
||||
'gradient_accumulation_steps': 4
|
||||
},
|
||||
'model_paths': {
|
||||
'cache_dir': './cache/data'
|
||||
},
|
||||
'data': {
|
||||
'max_seq_length': 512,
|
||||
'max_query_length': 64,
|
||||
'max_passage_length': 512,
|
||||
'random_seed': 42
|
||||
},
|
||||
'hardware': {
|
||||
'dataloader_num_workers': 4,
|
||||
'dataloader_pin_memory': True,
|
||||
'per_device_train_batch_size': 4,
|
||||
'per_device_eval_batch_size': 8
|
||||
},
|
||||
'optimization': {
|
||||
'fp16': True,
|
||||
'bf16': False,
|
||||
'gradient_checkpointing': True,
|
||||
'max_grad_norm': 1.0,
|
||||
'adam_epsilon': 1e-8,
|
||||
'adam_beta1': 0.9,
|
||||
'adam_beta2': 0.999
|
||||
},
|
||||
'checkpoint': {
|
||||
'save_steps': 1000,
|
||||
'save_total_limit': 3,
|
||||
'load_best_model_at_end': True,
|
||||
'metric_for_best_model': 'eval_loss',
|
||||
'greater_is_better': False
|
||||
},
|
||||
'logging': {
|
||||
'log_dir': './logs'
|
||||
}
|
||||
}
|
||||
|
||||
def get(self, key, default=None):
|
||||
keys = key.split('.')
|
||||
value = self._config
|
||||
try:
|
||||
for k in keys:
|
||||
value = value[k]
|
||||
return value
|
||||
except (KeyError, TypeError):
|
||||
return default
|
||||
|
||||
def get_section(self, section):
|
||||
return self._config.get(section, {})
|
||||
|
||||
return FallbackConfig()
|
||||
|
||||
def _apply_config_values(self, config):
|
||||
"""
|
||||
Apply configuration values with proper precedence order.
|
||||
|
||||
Precedence (lowest to highest):
|
||||
1. Dataclass field defaults (already set)
|
||||
2. [training] section values
|
||||
3. [hardware] section values (override training)
|
||||
4. Model-specific sections (handled in subclasses)
|
||||
5. Command-line arguments (handled by training scripts)
|
||||
|
||||
This method handles steps 2-3. Model-specific configs are applied
|
||||
in subclass __post_init__ methods AFTER this base method runs.
|
||||
"""
|
||||
try:
|
||||
# === STEP 2: Apply [training] section (general defaults) ===
|
||||
training_config = config.get_section('training')
|
||||
if training_config:
|
||||
# Only update if value exists in config
|
||||
if 'default_num_epochs' in training_config:
|
||||
old_value = self.num_train_epochs
|
||||
self.num_train_epochs = training_config['default_num_epochs']
|
||||
if old_value != self.num_train_epochs:
|
||||
logger.debug(f"[training] num_train_epochs: {old_value} -> {self.num_train_epochs}")
|
||||
|
||||
if 'default_batch_size' in training_config:
|
||||
self.per_device_train_batch_size = training_config['default_batch_size']
|
||||
logger.debug(f"[training] batch_size set to {self.per_device_train_batch_size}")
|
||||
|
||||
if 'default_learning_rate' in training_config:
|
||||
self.learning_rate = training_config['default_learning_rate']
|
||||
|
||||
if 'default_warmup_ratio' in training_config:
|
||||
self.warmup_ratio = training_config['default_warmup_ratio']
|
||||
|
||||
if 'gradient_accumulation_steps' in training_config:
|
||||
self.gradient_accumulation_steps = training_config['gradient_accumulation_steps']
|
||||
|
||||
if 'default_weight_decay' in training_config:
|
||||
self.weight_decay = training_config['default_weight_decay']
|
||||
|
||||
# === STEP 3: Apply [hardware] section (overrides training) ===
|
||||
hardware_config = config.get_section('hardware')
|
||||
if hardware_config:
|
||||
# Hardware settings take precedence over training defaults
|
||||
if 'dataloader_num_workers' in hardware_config:
|
||||
self.dataloader_num_workers = hardware_config['dataloader_num_workers']
|
||||
|
||||
if 'dataloader_pin_memory' in hardware_config:
|
||||
self.dataloader_pin_memory = hardware_config['dataloader_pin_memory']
|
||||
|
||||
# Batch sizes from hardware override training defaults
|
||||
if 'per_device_train_batch_size' in hardware_config:
|
||||
old_value = self.per_device_train_batch_size
|
||||
self.per_device_train_batch_size = hardware_config['per_device_train_batch_size']
|
||||
if old_value != self.per_device_train_batch_size:
|
||||
logger.info(f"[hardware] overrides batch_size: {old_value} -> {self.per_device_train_batch_size}")
|
||||
|
||||
if 'per_device_eval_batch_size' in hardware_config:
|
||||
self.per_device_eval_batch_size = hardware_config['per_device_eval_batch_size']
|
||||
|
||||
# === Apply other configuration sections ===
|
||||
|
||||
# Model paths
|
||||
if self.cache_dir is None:
|
||||
cache_dir = config.get('model_paths.cache_dir')
|
||||
if cache_dir is not None:
|
||||
self.cache_dir = cache_dir
|
||||
logger.debug(f"cache_dir set from config: {self.cache_dir}")
|
||||
|
||||
# Data configuration
|
||||
data_config = config.get_section('data')
|
||||
if data_config:
|
||||
if 'max_seq_length' in data_config:
|
||||
self.max_seq_length = data_config['max_seq_length']
|
||||
if 'max_query_length' in data_config:
|
||||
self.query_max_length = data_config['max_query_length']
|
||||
if 'max_passage_length' in data_config:
|
||||
self.passage_max_length = data_config['max_passage_length']
|
||||
if 'random_seed' in data_config:
|
||||
self.seed = data_config['random_seed']
|
||||
|
||||
# Optimization configuration
|
||||
opt_config = config.get_section('optimization')
|
||||
if opt_config:
|
||||
if 'fp16' in opt_config:
|
||||
self.fp16 = opt_config['fp16']
|
||||
if 'bf16' in opt_config:
|
||||
self.bf16 = opt_config['bf16']
|
||||
if 'gradient_checkpointing' in opt_config:
|
||||
self.gradient_checkpointing = opt_config['gradient_checkpointing']
|
||||
if 'max_grad_norm' in opt_config:
|
||||
self.max_grad_norm = opt_config['max_grad_norm']
|
||||
if 'adam_epsilon' in opt_config:
|
||||
self.adam_epsilon = opt_config['adam_epsilon']
|
||||
if 'adam_beta1' in opt_config:
|
||||
self.adam_beta1 = opt_config['adam_beta1']
|
||||
if 'adam_beta2' in opt_config:
|
||||
self.adam_beta2 = opt_config['adam_beta2']
|
||||
|
||||
# Checkpoint configuration
|
||||
checkpoint_config = config.get_section('checkpoint')
|
||||
if checkpoint_config:
|
||||
if 'save_steps' in checkpoint_config:
|
||||
self.save_steps = checkpoint_config['save_steps']
|
||||
if 'save_total_limit' in checkpoint_config:
|
||||
self.save_total_limit = checkpoint_config['save_total_limit']
|
||||
if 'load_best_model_at_end' in checkpoint_config:
|
||||
self.load_best_model_at_end = checkpoint_config['load_best_model_at_end']
|
||||
if 'metric_for_best_model' in checkpoint_config:
|
||||
self.metric_for_best_model = checkpoint_config['metric_for_best_model']
|
||||
if 'greater_is_better' in checkpoint_config:
|
||||
self.greater_is_better = checkpoint_config['greater_is_better']
|
||||
|
||||
# Logging configuration
|
||||
logging_config = config.get_section('logging')
|
||||
if logging_config:
|
||||
if 'log_dir' in logging_config:
|
||||
self.logging_dir = logging_config['log_dir']
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying configuration values: {e}")
|
||||
|
||||
def _setup_device(self):
|
||||
"""Enhanced device detection with Ascend NPU support"""
|
||||
if self.device == "auto":
|
||||
device_info = self._detect_available_devices()
|
||||
self.device = device_info['device']
|
||||
self.n_gpu = device_info['count']
|
||||
logger.info(f"Auto-detected device: {self.device} (count: {self.n_gpu})")
|
||||
else:
|
||||
# Manual device specification
|
||||
if self.device == "cuda" and torch.cuda.is_available():
|
||||
self.n_gpu = torch.cuda.device_count()
|
||||
elif self.device == "npu":
|
||||
try:
|
||||
import torch_npu # type: ignore[import]
|
||||
npu = getattr(torch, 'npu', None)
|
||||
if npu is not None and npu.is_available():
|
||||
self.n_gpu = npu.device_count()
|
||||
else:
|
||||
logger.warning("NPU requested but not available, falling back to CPU")
|
||||
self.device = "cpu"
|
||||
self.n_gpu = 0
|
||||
except ImportError:
|
||||
logger.warning("NPU requested but torch_npu not installed, falling back to CPU")
|
||||
self.device = "cpu"
|
||||
self.n_gpu = 0
|
||||
else:
|
||||
self.n_gpu = 0
|
||||
|
||||
def _detect_available_devices(self) -> Dict[str, Any]:
|
||||
"""Comprehensive device detection with priority ordering"""
|
||||
device_priority = ['npu', 'cuda', 'mps', 'cpu']
|
||||
available_devices = []
|
||||
|
||||
# Check Ascend NPU
|
||||
try:
|
||||
import torch_npu # type: ignore[import]
|
||||
npu = getattr(torch, 'npu', None)
|
||||
if npu is not None and npu.is_available():
|
||||
count = npu.device_count()
|
||||
available_devices.append(('npu', count))
|
||||
logger.debug(f"Ascend NPU available: {count} devices")
|
||||
except ImportError:
|
||||
logger.debug("Ascend NPU library not available")
|
||||
|
||||
# Check CUDA
|
||||
if torch.cuda.is_available():
|
||||
count = torch.cuda.device_count()
|
||||
available_devices.append(('cuda', count))
|
||||
logger.debug(f"CUDA available: {count} devices")
|
||||
|
||||
# Check MPS (Apple Silicon)
|
||||
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and hasattr(torch.backends.mps, 'is_available') and torch.backends.mps.is_available(): # type: ignore[attr-defined]
|
||||
available_devices.append(('mps', 1))
|
||||
logger.debug("Apple MPS available")
|
||||
|
||||
# CPU always available
|
||||
available_devices.append(('cpu', 1))
|
||||
|
||||
# Select best available device based on priority
|
||||
for device_type in device_priority:
|
||||
for available_type, count in available_devices:
|
||||
if available_type == device_type:
|
||||
return {'device': device_type, 'count': count}
|
||||
|
||||
return {'device': 'cpu', 'count': 1} # Fallback
|
||||
|
||||
def validate(self):
|
||||
"""Comprehensive configuration validation"""
|
||||
errors = []
|
||||
# Check required paths
|
||||
if self.train_data_path and not os.path.exists(self.train_data_path):
|
||||
errors.append(f"Training data not found: {self.train_data_path}")
|
||||
if self.eval_data_path and not os.path.exists(self.eval_data_path):
|
||||
errors.append(f"Evaluation data not found: {self.eval_data_path}")
|
||||
# Validate numeric parameters
|
||||
if self.learning_rate is None or self.learning_rate <= 0:
|
||||
errors.append("Learning rate must be positive and not None")
|
||||
if self.num_train_epochs is None or self.num_train_epochs <= 0:
|
||||
errors.append("Number of epochs must be positive and not None")
|
||||
if self.per_device_train_batch_size is None or self.per_device_train_batch_size <= 0:
|
||||
errors.append("Training batch size must be positive and not None")
|
||||
if self.gradient_accumulation_steps is None or self.gradient_accumulation_steps <= 0:
|
||||
errors.append("Gradient accumulation steps must be positive and not None")
|
||||
if self.warmup_ratio is None or self.warmup_ratio < 0 or self.warmup_ratio > 1:
|
||||
errors.append("Warmup ratio must be between 0 and 1 and not None")
|
||||
# Validate device
|
||||
if self.device not in ['auto', 'cpu', 'cuda', 'npu', 'mps']:
|
||||
errors.append(f"Invalid device: {self.device}")
|
||||
# Check sequence lengths
|
||||
if self.max_seq_length is None or self.max_seq_length <= 0:
|
||||
errors.append("Maximum sequence length must be positive and not None")
|
||||
if self.query_max_length is None or self.query_max_length <= 0:
|
||||
errors.append("Maximum query length must be positive and not None")
|
||||
if self.passage_max_length is None or self.passage_max_length <= 0:
|
||||
errors.append("Maximum passage length must be positive and not None")
|
||||
if errors:
|
||||
raise ValueError(f"Configuration validation failed:\n" + "\n".join(f" - {error}" for error in errors))
|
||||
logger.info("Configuration validation passed")
|
||||
return True
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert config to dictionary with type handling"""
|
||||
config_dict = {}
|
||||
for key, value in self.__dict__.items():
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
config_dict[key] = value
|
||||
elif isinstance(value, (list, tuple)):
|
||||
config_dict[key] = list(value)
|
||||
elif isinstance(value, dict):
|
||||
config_dict[key] = dict(value)
|
||||
else:
|
||||
config_dict[key] = str(value)
|
||||
return config_dict
|
||||
|
||||
def save(self, save_path: str):
|
||||
"""Save configuration to file with error handling"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
|
||||
import json
|
||||
with open(save_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"Configuration saved to {save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save configuration: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def load(cls, load_path: str):
|
||||
"""Load configuration from file with validation"""
|
||||
try:
|
||||
import json
|
||||
with open(load_path, 'r', encoding='utf-8') as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Filter only valid field names
|
||||
import inspect
|
||||
sig = inspect.signature(cls)
|
||||
valid_keys = set(sig.parameters.keys())
|
||||
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
|
||||
|
||||
config = cls(**filtered_dict)
|
||||
config.validate() # Validate after loading
|
||||
|
||||
logger.info(f"Configuration loaded from {load_path}")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load configuration from {load_path}: {e}")
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]):
|
||||
"""Create config from dictionary with validation"""
|
||||
try:
|
||||
# Filter only valid field names
|
||||
import inspect
|
||||
sig = inspect.signature(cls)
|
||||
valid_keys = set(sig.parameters.keys())
|
||||
filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
|
||||
|
||||
config = cls(**filtered_dict)
|
||||
config.validate() # Validate after creation
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create configuration from dictionary: {e}")
|
||||
raise
|
||||
|
||||
def update_from_args(self, args):
|
||||
"""Update config from command line arguments with validation"""
|
||||
try:
|
||||
updated_fields = []
|
||||
for key, value in vars(args).items():
|
||||
if hasattr(self, key) and value is not None:
|
||||
old_value = getattr(self, key)
|
||||
setattr(self, key, value)
|
||||
updated_fields.append(f"{key}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
logger.info("Updated configuration from command line arguments:")
|
||||
for field in updated_fields:
|
||||
logger.info(f" {field}")
|
||||
|
||||
# Validate after updates
|
||||
self.validate()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update configuration from arguments: {e}")
|
||||
raise
|
||||
|
||||
def get_effective_batch_size(self) -> int:
|
||||
"""Calculate effective batch size considering gradient accumulation"""
|
||||
return self.per_device_train_batch_size * self.gradient_accumulation_steps * max(1, self.n_gpu)
|
||||
|
||||
def get_training_steps(self, dataset_size: int) -> int:
|
||||
"""Calculate total training steps"""
|
||||
steps_per_epoch = dataset_size // self.get_effective_batch_size()
|
||||
return steps_per_epoch * self.num_train_epochs
|
||||
|
||||
def print_effective_config(self):
|
||||
"""Print the effective value and its source for each parameter."""
|
||||
print("Parameter | Value | Source")
|
||||
print("----------|-------|-------")
|
||||
for field in self.__dataclass_fields__:
|
||||
value = getattr(self, field)
|
||||
# For now, just print the value; source tracking can be added if needed
|
||||
print(f"{field} | {value} | (see docstring for precedence)")
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of configuration"""
|
||||
return f"BaseConfig(device={self.device}, batch_size={self.per_device_train_batch_size}, lr={self.learning_rate})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
def get_device(self):
|
||||
"""Return the device string for training (cuda, npu, mps, cpu)"""
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
npu = getattr(torch, "npu", None)
|
||||
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
|
||||
return "npu"
|
||||
backends = getattr(torch, "backends", None)
|
||||
if backends is not None and hasattr(backends, "mps") and backends.mps.is_available():
|
||||
return "mps"
|
||||
return "cpu"
|
||||
|
||||
def is_npu(self):
|
||||
"""Return True if NPU is available and selected as device."""
|
||||
npu = getattr(torch, "npu", None)
|
||||
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
|
||||
device = self.get_device() if hasattr(self, 'get_device') else "cpu"
|
||||
return device == "npu"
|
||||
return False
|
||||
Reference in New Issue
Block a user