# (File-viewer metadata: 561 lines, 24 KiB, Python)
import os
|
|
import torch
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional, List, Dict, Any
|
|
import logging
|
|
|
|
# Module-level logger; every BaseConfig method in this file reports through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class BaseConfig:
    """
    Base configuration for all trainers and models.

    Configuration Precedence Order:
    1. Command-line arguments (highest priority)
    2. Model-specific config sections ([m3], [reranker])
    3. [hardware] section (for hardware-specific overrides)
    4. [training] section (general training defaults)
    5. Dataclass field defaults (lowest priority)

    Example:
    If batch_size is defined in multiple places:
    - CLI arg --batch_size 32 (wins)
    - [m3] batch_size = 16
    - [hardware] per_device_train_batch_size = 8
    - [training] default_batch_size = 4
    - Dataclass default = 2
    (values are illustrative; the actual per_device_train_batch_size field
    default is 4)

    Note: Model-specific configs ([m3], [reranker]) take precedence over
    [hardware] for their respective models to allow fine-grained control.
    This base class applies [training], [hardware] and the shared sections
    ([data], [optimization], [checkpoint], [logging], [model_paths]).
    Model-specific sections are applied by subclasses. Command-line arguments
    are applied afterwards (e.g. via ``update_from_args``).
    """

    # --- Loss configuration ---
    temperature: float = 1.0  # Softmax temperature for loss functions (must be > 0)
    use_in_batch_negatives: bool = False  # Use in-batch negatives for contrastive loss
    negatives_cross_device: bool = False  # Use negatives across devices (DDP)
    label_smoothing: float = 0.0  # Label smoothing for classification losses
    alpha: float = 1.0  # Weight for auxiliary losses

    # --- Training loop ---
    gradient_accumulation_steps: int = 1  # Steps to accumulate gradients
    num_train_epochs: int = 3  # Number of training epochs
    # NOTE(review): these were documented as "every N epochs", but
    # evaluation_strategy defaults to "steps" and the fallback config uses
    # save_steps=1000. Confirm the unit against the trainer.
    eval_steps: int = 1  # Evaluation interval
    save_steps: int = 1  # Checkpoint-saving interval
    warmup_steps: int = 0  # Number of warmup steps for scheduler
    warmup_ratio: float = 0.0  # Warmup ratio for scheduler, in [0, 1]
    # NOTE(review): "accuracy" is normally maximized, yet greater_is_better
    # defaults to False. validate() warns about this combination.
    metric_for_best_model: str = "accuracy"  # Metric to select best model
    checkpoint_dir: str = "./checkpoints"  # Directory for saving checkpoints
    resume_from_checkpoint: str = ""  # Path to resume checkpoint
    max_grad_norm: float = 1.0  # Max gradient norm for clipping
    use_amp: bool = False  # Use automatic mixed precision

    # --- Model configuration ---
    model_name_or_path: str = "BAAI/bge-m3"
    cache_dir: Optional[str] = None  # None => taken from [model_paths] cache_dir

    # --- Data configuration ---
    train_data_path: str = "./data/train.jsonl"
    eval_data_path: Optional[str] = "./data/eval.jsonl"
    max_seq_length: int = 512
    query_max_length: int = 64
    passage_max_length: int = 512

    # --- Optimizer configuration ---
    output_dir: str = "./output"
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    adam_epsilon: float = 1e-8
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999

    # --- Advanced training configuration ---
    fp16: bool = True
    bf16: bool = False
    gradient_checkpointing: bool = True
    deepspeed: Optional[str] = None
    local_rank: int = -1
    dataloader_num_workers: int = 4
    dataloader_pin_memory: bool = True

    # --- Batch sizes ---
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 8

    # --- Logging and saving ---
    logging_dir: str = "./logs"
    logging_steps: int = 100  # Log every N steps
    save_total_limit: int = 3
    load_best_model_at_end: bool = True
    greater_is_better: bool = False
    evaluation_strategy: str = "steps"

    # --- Checkpoint and resume ---
    overwrite_output_dir: bool = False

    # --- Seed ---
    seed: int = 42

    # --- Platform compatibility ---
    device: str = "auto"  # "auto" is resolved to a concrete backend in __post_init__
    n_gpu: int = 0  # Number of accelerator devices; 0 on CPU

    # Config-file sections applied by _apply_config_values, in application
    # order. Each pair maps a key inside the section to a dataclass attribute.
    # Later sections win on overlap, which is how [hardware] overrides
    # [training] for per_device_train_batch_size. These are unannotated class
    # attributes, so @dataclass does not treat them as fields.
    _CONFIG_SECTION_MAP = (
        ("training", (
            ("default_num_epochs", "num_train_epochs"),
            ("default_batch_size", "per_device_train_batch_size"),
            ("default_learning_rate", "learning_rate"),
            ("default_warmup_ratio", "warmup_ratio"),
            ("gradient_accumulation_steps", "gradient_accumulation_steps"),
            ("default_weight_decay", "weight_decay"),
        )),
        ("hardware", (
            ("dataloader_num_workers", "dataloader_num_workers"),
            ("dataloader_pin_memory", "dataloader_pin_memory"),
            ("per_device_train_batch_size", "per_device_train_batch_size"),
            ("per_device_eval_batch_size", "per_device_eval_batch_size"),
        )),
        ("data", (
            ("max_seq_length", "max_seq_length"),
            ("max_query_length", "query_max_length"),
            ("max_passage_length", "passage_max_length"),
            ("random_seed", "seed"),
        )),
        ("optimization", (
            ("fp16", "fp16"),
            ("bf16", "bf16"),
            ("gradient_checkpointing", "gradient_checkpointing"),
            ("max_grad_norm", "max_grad_norm"),
            ("adam_epsilon", "adam_epsilon"),
            ("adam_beta1", "adam_beta1"),
            ("adam_beta2", "adam_beta2"),
        )),
        ("checkpoint", (
            ("save_steps", "save_steps"),
            ("save_total_limit", "save_total_limit"),
            ("load_best_model_at_end", "load_best_model_at_end"),
            ("metric_for_best_model", "metric_for_best_model"),
            ("greater_is_better", "greater_is_better"),
        )),
        ("logging", (
            ("log_dir", "logging_dir"),
        )),
    )

    # Module paths tried, in order, to locate the centralized config loader.
    _CONFIG_LOADER_MODULES = (
        "utils.config_loader",
        "bge_finetune.utils.config_loader",
    )

    # Fields resolved for the current machine on every construction.
    # from_dict/load do not force saved values for these back, because the
    # machine loading a config may differ from the one that saved it.
    _RUNTIME_DERIVED_FIELDS = frozenset({"device", "n_gpu"})

    def __post_init__(self):
        """Post-initialization with robust config loading and device setup.

        Notes:
        - [hardware] section overrides [training] for overlapping parameters.
        - Model-specific configs ([m3], [reranker]) override [training]/[hardware]
          for their models; subclasses apply them after this method runs.
        """
        # Load configuration from the centralized config system (or fallback).
        config = self._get_config_safely()
        # Overlay config-file values onto the dataclass defaults.
        self._apply_config_values(config)
        # Resolve "auto"/unavailable devices and set n_gpu.
        self._setup_device()

    def _get_config_safely(self):
        """Return the centralized config object, or a fallback if unavailable.

        Each module in ``_CONFIG_LOADER_MODULES`` is tried in order. The next
        candidate is tried when the module cannot be imported, has no
        ``get_config``, or ``get_config()`` itself raises ImportError. If every
        candidate fails, a built-in fallback configuration is returned.
        """
        import importlib

        for module_path in self._CONFIG_LOADER_MODULES:
            try:
                get_config = getattr(importlib.import_module(module_path), "get_config")
            except (ImportError, AttributeError):
                logger.debug(f"Config loader not importable from {module_path}")
                continue
            try:
                return get_config()
            except ImportError as e:
                # The loader exists but one of its own imports failed.
                logger.debug(f"{module_path}.get_config() failed to import a dependency: {e}")

        logger.error("Could not import config_loader, using fallback configuration. Please check your environment and config loader path.")
        return self._create_fallback_config()

    def _create_fallback_config(self):
        """Create a minimal fallback configuration object.

        The returned object mimics the two methods this class uses on the real
        config: ``get('a.b', default)`` for dotted lookups and
        ``get_section(name)`` for a whole section dict.
        """

        class FallbackConfig:
            def __init__(self):
                self._config = {
                    'training': {
                        'default_batch_size': 4,
                        'default_learning_rate': 1e-5,
                        'default_num_epochs': 3,
                        'default_warmup_ratio': 0.1,
                        'default_weight_decay': 0.01,
                        'gradient_accumulation_steps': 4
                    },
                    'model_paths': {
                        'cache_dir': './cache/data'
                    },
                    'data': {
                        'max_seq_length': 512,
                        'max_query_length': 64,
                        'max_passage_length': 512,
                        'random_seed': 42
                    },
                    'hardware': {
                        'dataloader_num_workers': 4,
                        'dataloader_pin_memory': True,
                        'per_device_train_batch_size': 4,
                        'per_device_eval_batch_size': 8
                    },
                    'optimization': {
                        'fp16': True,
                        'bf16': False,
                        'gradient_checkpointing': True,
                        'max_grad_norm': 1.0,
                        'adam_epsilon': 1e-8,
                        'adam_beta1': 0.9,
                        'adam_beta2': 0.999
                    },
                    'checkpoint': {
                        'save_steps': 1000,
                        'save_total_limit': 3,
                        'load_best_model_at_end': True,
                        'metric_for_best_model': 'eval_loss',
                        'greater_is_better': False
                    },
                    'logging': {
                        'log_dir': './logs'
                    }
                }

            def get(self, key, default=None):
                """Look up a dotted key such as 'model_paths.cache_dir'."""
                node = self._config
                for part in key.split('.'):
                    if not isinstance(node, dict) or part not in node:
                        return default
                    node = node[part]
                return node

            def get_section(self, section):
                """Return the named section dict, or {} if absent."""
                return self._config.get(section, {})

        return FallbackConfig()

    def _apply_config_values(self, config):
        """
        Apply configuration values with proper precedence order.

        Precedence (lowest to highest):
        1. Dataclass field defaults (already set)
        2. [training] section values
        3. [hardware] section values (override training)
        4. Model-specific sections (handled in subclasses)
        5. Command-line arguments (handled by training scripts)

        This method handles steps 2-3 plus the non-overlapping shared sections
        listed in ``_CONFIG_SECTION_MAP``. Model-specific configs are applied
        in subclass __post_init__ methods AFTER this base method runs.

        Each section is applied independently, on a best-effort basis. An
        error in one section is logged and does not prevent later sections
        from being applied.
        """
        for section_name, key_map in self._CONFIG_SECTION_MAP:
            try:
                section = config.get_section(section_name)
                if section:
                    self._apply_section(section_name, section, key_map)
            except Exception as e:
                logger.warning(f"Error applying [{section_name}] configuration values: {e}")

        # cache_dir is only taken from the config when not set explicitly.
        if self.cache_dir is None:
            try:
                cache_dir = config.get('model_paths.cache_dir')
            except Exception as e:
                logger.warning(f"Error reading model_paths.cache_dir: {e}")
                cache_dir = None
            if cache_dir is not None:
                self.cache_dir = cache_dir
                logger.debug(f"cache_dir set from config: {self.cache_dir}")

    def _apply_section(self, section_name, section, key_map):
        """Copy each ``(config_key, attribute)`` pair present in ``section``
        onto ``self``, logging values that change."""
        for config_key, attr in key_map:
            if config_key not in section:
                continue
            old_value = getattr(self, attr)
            new_value = section[config_key]
            setattr(self, attr, new_value)
            if old_value == new_value:
                continue
            if section_name == "hardware" and attr == "per_device_train_batch_size":
                # Surface the most common surprising override at info level.
                logger.info(f"[hardware] overrides batch_size: {old_value} -> {new_value}")
            else:
                logger.debug(f"[{section_name}] {attr}: {old_value} -> {new_value}")

    def _setup_device(self):
        """Resolve ``self.device`` to a usable backend and set ``self.n_gpu``.

        ``"auto"`` picks the best available backend (see
        ``_detect_available_devices``). An explicitly requested accelerator
        that is not usable falls back to CPU with a warning, so ``device``
        always names a backend that can actually be used.
        """
        if self.device == "auto":
            device_info = self._detect_available_devices()
            self.device = device_info['device']
            self.n_gpu = device_info['count']
            logger.info(f"Auto-detected device: {self.device} (count: {self.n_gpu})")
            return

        if self.device == "cuda":
            if torch.cuda.is_available():
                self.n_gpu = torch.cuda.device_count()
            else:
                self._fall_back_to_cpu("CUDA requested but not available, falling back to CPU")
        elif self.device == "npu":
            npu_count = self._probe_npu()
            if npu_count is None:
                self._fall_back_to_cpu("NPU requested but torch_npu not installed, falling back to CPU")
            elif npu_count == 0:
                self._fall_back_to_cpu("NPU requested but not available, falling back to CPU")
            else:
                self.n_gpu = npu_count
        elif self.device == "mps":
            if self._mps_available():
                self.n_gpu = 1  # MPS exposes a single device (matches auto-detection)
            else:
                self._fall_back_to_cpu("MPS requested but not available, falling back to CPU")
        else:
            # "cpu", or an unrecognized value that validate() will reject.
            self.n_gpu = 0

    def _fall_back_to_cpu(self, reason: str) -> None:
        """Log ``reason`` as a warning and switch to CPU with no accelerators."""
        logger.warning(reason)
        self.device = "cpu"
        self.n_gpu = 0

    @staticmethod
    def _probe_npu() -> Optional[int]:
        """Return the Ascend NPU count.

        Returns None if torch_npu is not installed, and 0 if it is installed
        but no NPU is available.
        """
        try:
            # Imported for its side effect of exposing torch.npu.
            import torch_npu  # type: ignore[import]  # noqa: F401
        except ImportError:
            return None
        npu = getattr(torch, 'npu', None)
        if npu is not None and npu.is_available():
            return npu.device_count()
        return 0

    @staticmethod
    def _mps_available() -> bool:
        """Return True if this torch build has an available Apple MPS backend."""
        mps = getattr(getattr(torch, 'backends', None), 'mps', None)
        is_available = getattr(mps, 'is_available', None)
        return bool(is_available is not None and is_available())

    def _detect_available_devices(self) -> Dict[str, Any]:
        """Return ``{'device': name, 'count': n}`` for the best usable backend.

        Priority: Ascend NPU, CUDA, Apple MPS, CPU. ``count`` is the number of
        accelerator devices. CPU reports 0, so an auto-detected CPU yields the
        same ``n_gpu`` as an explicit ``device="cpu"``.
        ``get_effective_batch_size`` clamps it to at least 1.
        """
        npu_count = self._probe_npu()
        if npu_count:
            logger.debug(f"Ascend NPU available: {npu_count} devices")
            return {'device': 'npu', 'count': npu_count}
        if npu_count is None:
            logger.debug("Ascend NPU library not available")

        if torch.cuda.is_available():
            count = torch.cuda.device_count()
            logger.debug(f"CUDA available: {count} devices")
            return {'device': 'cuda', 'count': count}

        if self._mps_available():
            logger.debug("Apple MPS available")
            return {'device': 'mps', 'count': 1}

        return {'device': 'cpu', 'count': 0}

    def validate(self):
        """Comprehensive configuration validation.

        Returns:
            True when the configuration is valid.

        Raises:
            ValueError: listing every detected problem at once.

        Suspicious-but-legal combinations (e.g. fp16 and bf16 both enabled)
        are reported as warnings rather than errors.
        """
        errors = []
        # Check required paths
        if self.train_data_path and not os.path.exists(self.train_data_path):
            errors.append(f"Training data not found: {self.train_data_path}")
        if self.eval_data_path and not os.path.exists(self.eval_data_path):
            errors.append(f"Evaluation data not found: {self.eval_data_path}")
        # Validate numeric parameters
        if self.learning_rate is None or self.learning_rate <= 0:
            errors.append("Learning rate must be positive and not None")
        if self.num_train_epochs is None or self.num_train_epochs <= 0:
            errors.append("Number of epochs must be positive and not None")
        if self.per_device_train_batch_size is None or self.per_device_train_batch_size <= 0:
            errors.append("Training batch size must be positive and not None")
        if self.gradient_accumulation_steps is None or self.gradient_accumulation_steps <= 0:
            errors.append("Gradient accumulation steps must be positive and not None")
        if self.warmup_ratio is None or self.warmup_ratio < 0 or self.warmup_ratio > 1:
            errors.append("Warmup ratio must be between 0 and 1 and not None")
        # Temperature divides logits in the loss, so it must be strictly positive.
        if self.temperature is None or self.temperature <= 0:
            errors.append("Temperature must be positive and not None")
        # Validate device
        if self.device not in ['auto', 'cpu', 'cuda', 'npu', 'mps']:
            errors.append(f"Invalid device: {self.device}")
        # Check sequence lengths
        if self.max_seq_length is None or self.max_seq_length <= 0:
            errors.append("Maximum sequence length must be positive and not None")
        if self.query_max_length is None or self.query_max_length <= 0:
            errors.append("Maximum query length must be positive and not None")
        if self.passage_max_length is None or self.passage_max_length <= 0:
            errors.append("Maximum passage length must be positive and not None")

        if errors:
            raise ValueError("Configuration validation failed:\n" + "\n".join(f"  - {error}" for error in errors))

        self._warn_on_suspicious_settings()
        logger.info("Configuration validation passed")
        return True

    def _warn_on_suspicious_settings(self):
        """Log warnings for legal but likely-unintended setting combinations."""
        if self.fp16 and self.bf16:
            logger.warning("Both fp16 and bf16 are enabled; only one mixed-precision mode should be set")

        metric = (self.metric_for_best_model or "").lower()
        if "loss" in metric and self.greater_is_better:
            logger.warning(f"metric_for_best_model={self.metric_for_best_model!r} looks like a loss but greater_is_better=True")
        elif any(name in metric for name in ("accuracy", "f1", "recall", "precision", "mrr", "ndcg")) and not self.greater_is_better:
            logger.warning(f"metric_for_best_model={self.metric_for_best_model!r} is usually maximized but greater_is_better=False")

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a JSON-friendly dictionary.

        Scalars and None are kept, lists/tuples become lists, dicts are
        shallow-copied, and any other value is stringified.
        """
        def _plain(value):
            if isinstance(value, (str, int, float, bool, type(None))):
                return value
            if isinstance(value, (list, tuple)):
                return list(value)
            if isinstance(value, dict):
                return dict(value)
            return str(value)

        return {key: _plain(value) for key, value in self.__dict__.items()}

    def save(self, save_path: str):
        """Save configuration as JSON to ``save_path``.

        Parent directories are created when ``save_path`` has any; a bare
        filename is written to the current directory. Errors are logged and
        re-raised.
        """
        try:
            parent_dir = os.path.dirname(save_path)
            if parent_dir:
                # os.makedirs("") raises, so only create a real parent path.
                os.makedirs(parent_dir, exist_ok=True)

            import json
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

            logger.info(f"Configuration saved to {save_path}")
        except Exception as e:
            logger.error(f"Failed to save configuration: {e}")
            raise

    @classmethod
    def _build_from_mapping(cls, config_dict: Dict[str, Any]):
        """Construct, re-apply explicit values to, and validate an instance.

        Unknown keys are ignored. ``__post_init__`` applies config-file
        sections on top of constructor arguments, so the explicitly supplied
        values are written back afterwards to make them win. Without this, a
        saved config would not round-trip through ``load``. Fields in
        ``_RUNTIME_DERIVED_FIELDS`` are left as resolved for this machine.
        """
        import inspect

        valid_keys = set(inspect.signature(cls).parameters)
        filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}

        config = cls(**filtered_dict)
        for key, value in filtered_dict.items():
            if key not in cls._RUNTIME_DERIVED_FIELDS:
                setattr(config, key, value)

        config.validate()
        return config

    @classmethod
    def load(cls, load_path: str):
        """Load a JSON configuration written by ``save`` and validate it.

        Values stored in the file take precedence over config-file sections.
        Errors are logged and re-raised.
        """
        try:
            import json
            with open(load_path, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)

            config = cls._build_from_mapping(config_dict)

            logger.info(f"Configuration loaded from {load_path}")
            return config

        except Exception as e:
            logger.error(f"Failed to load configuration from {load_path}: {e}")
            raise

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]):
        """Create a validated config from a dictionary.

        Unknown keys are ignored, and supplied values take precedence over
        config-file sections. Errors are logged and re-raised.
        """
        try:
            return cls._build_from_mapping(config_dict)
        except Exception as e:
            logger.error(f"Failed to create configuration from dictionary: {e}")
            raise

    def update_from_args(self, args):
        """Update config from command line arguments, then validate.

        Args:
            args: an ``argparse.Namespace`` (or any object supporting
                ``vars()``), or a plain dict. Entries whose value is None are
                skipped. Keys that are not attributes of this config, or that
                name a method, are also skipped.

        Raises:
            ValueError: if the updated configuration fails validation.
        """
        try:
            arg_items = args.items() if isinstance(args, dict) else vars(args).items()
            updated_fields = []
            for key, value in arg_items:
                if value is None or not hasattr(self, key):
                    continue
                old_value = getattr(self, key)
                if callable(old_value):
                    # Never let an argument named like a method (e.g. "save")
                    # replace that method on the instance.
                    continue
                setattr(self, key, value)
                updated_fields.append(f"{key}: {old_value} -> {value}")

            if updated_fields:
                logger.info("Updated configuration from command line arguments:")
                for change in updated_fields:
                    logger.info(f"  {change}")

            # Validate after updates
            self.validate()

        except Exception as e:
            logger.error(f"Failed to update configuration from arguments: {e}")
            raise

    def get_effective_batch_size(self) -> int:
        """Calculate the effective batch size: per-device batch size x
        gradient accumulation steps x device count (at least 1)."""
        return self.per_device_train_batch_size * self.gradient_accumulation_steps * max(1, self.n_gpu)

    def get_training_steps(self, dataset_size: int) -> int:
        """Calculate total optimizer steps for ``num_train_epochs`` epochs.

        Steps per epoch is ``dataset_size // effective_batch_size``, floored at
        1 for a non-empty dataset. Otherwise a dataset smaller than one
        effective batch would yield 0 steps, i.e. no training. An empty (or
        negative) dataset size yields 0.
        """
        if dataset_size <= 0:
            return 0
        steps_per_epoch = max(1, dataset_size // self.get_effective_batch_size())
        return steps_per_epoch * self.num_train_epochs

    def print_effective_config(self):
        """Print the effective value of each dataclass field as a table.

        Per-field source tracking is not implemented; the "Source" column
        points to the precedence rules in the class docstring.
        """
        print("Parameter | Value | Source")
        print("----------|-------|-------")
        for name in self.__dataclass_fields__:
            value = getattr(self, name)
            print(f"{name} | {value} | (see docstring for precedence)")

    def __str__(self) -> str:
        """Short summary using the concrete (sub)class name."""
        return f"{type(self).__name__}(device={self.device}, batch_size={self.per_device_train_batch_size}, lr={self.learning_rate})"

    def __repr__(self) -> str:
        return self.__str__()

    def get_device(self):
        """Return the device string for training (cuda, npu, mps, cpu).

        This is the device resolved in ``__post_init__``. If ``device`` is
        still "auto" (or unset), the best backend is detected using the same
        priority as ``__post_init__``.
        """
        device = getattr(self, "device", "auto")
        if device and device != "auto":
            return device
        return self._detect_available_devices()['device']

    def is_npu(self):
        """Return True if an NPU is available AND is the selected device."""
        npu = getattr(torch, "npu", None)
        available = npu is not None and hasattr(npu, "is_available") and npu.is_available()
        return bool(available) and self.get_device() == "npu"