# bge_finetune/config/base_config.py
# Snapshot: 2025-07-22 16:55:25 +08:00 (561 lines, 24 KiB, Python)

import os
import torch
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
import logging
logger = logging.getLogger(__name__)

# (config key, attribute name) pairs copied verbatim from the centralized
# config sections by BaseConfig._apply_config_values. Keys that need special
# handling (extra logging, conditional assignment) are applied explicitly there.
_TRAINING_KEY_MAP = (
    ('default_learning_rate', 'learning_rate'),
    ('default_warmup_ratio', 'warmup_ratio'),
    ('gradient_accumulation_steps', 'gradient_accumulation_steps'),
    ('default_weight_decay', 'weight_decay'),
)
_HARDWARE_KEY_MAP = (
    ('dataloader_num_workers', 'dataloader_num_workers'),
    ('dataloader_pin_memory', 'dataloader_pin_memory'),
    ('per_device_eval_batch_size', 'per_device_eval_batch_size'),
)
_DATA_KEY_MAP = (
    ('max_seq_length', 'max_seq_length'),
    ('max_query_length', 'query_max_length'),
    ('max_passage_length', 'passage_max_length'),
    ('random_seed', 'seed'),
)
_OPTIMIZATION_KEY_MAP = tuple(
    (name, name)
    for name in (
        'fp16', 'bf16', 'gradient_checkpointing', 'max_grad_norm',
        'adam_epsilon', 'adam_beta1', 'adam_beta2',
    )
)
_CHECKPOINT_KEY_MAP = tuple(
    (name, name)
    for name in (
        'save_steps', 'save_total_limit', 'load_best_model_at_end',
        'metric_for_best_model', 'greater_is_better',
    )
)
_LOGGING_KEY_MAP = (('log_dir', 'logging_dir'),)

# Fields that __post_init__ resolves for the current machine (_setup_device);
# they are not restored verbatim when building a config from a dictionary.
_MACHINE_RESOLVED_FIELDS = frozenset({'device', 'n_gpu'})


@dataclass
class BaseConfig:
    """
    Base configuration for all trainers and models.

    Configuration precedence order (highest first):
        1. Command-line arguments (applied via ``update_from_args``)
        2. Model-specific config sections ([m3], [reranker]), applied by subclasses
        3. [hardware] section (hardware-specific overrides)
        4. [training] section (general training defaults)
        5. Dataclass field defaults (lowest priority)

    Example (illustrative values) -- if the training batch size is defined
    in multiple places:
        - CLI arg --batch_size 32 (wins)
        - [m3] batch_size = 16
        - [hardware] per_device_train_batch_size = 8
        - [training] default_batch_size = 4
        - Dataclass default

    Note: model-specific configs ([m3], [reranker]) take precedence over
    [hardware] for their respective models to allow fine-grained control.

    Note: plain constructor keyword arguments are merged *under* the
    centralized config -- ``__post_init__`` overwrites any field the config
    file sets. ``from_dict``/``load`` re-apply the caller's values afterwards
    so that a saved configuration round-trips faithfully.
    """

    # --- Loss configuration ---
    temperature: float = 1.0  # Softmax temperature for loss functions
    use_in_batch_negatives: bool = False  # Use in-batch negatives for contrastive loss
    negatives_cross_device: bool = False  # Use negatives across devices (DDP)
    label_smoothing: float = 0.0  # Label smoothing for classification losses
    alpha: float = 1.0  # Weight for auxiliary losses

    # --- Training loop ---
    gradient_accumulation_steps: int = 1  # Steps to accumulate gradients
    num_train_epochs: int = 3  # Number of training epochs
    # NOTE(review): these were documented as "every N epochs", but
    # evaluation_strategy defaults to "steps" and the fallback config uses
    # save_steps=1000 -- confirm the intended unit with the trainers.
    eval_steps: int = 1  # Evaluation interval
    save_steps: int = 1  # Checkpoint interval
    warmup_steps: int = 0  # Number of warmup steps for scheduler
    warmup_ratio: float = 0.0  # Warmup ratio for scheduler
    # NOTE(review): "accuracy" together with greater_is_better=False (below)
    # would prefer the *lowest* accuracy; the fallback config pairs
    # greater_is_better=False with "eval_loss" instead -- confirm.
    metric_for_best_model: str = "accuracy"  # Metric to select best model
    checkpoint_dir: str = "./checkpoints"  # Directory for saving checkpoints
    resume_from_checkpoint: str = ""  # Path to resume checkpoint
    max_grad_norm: float = 1.0  # Max gradient norm for clipping
    use_amp: bool = False  # Use automatic mixed precision

    # --- Model configuration ---
    model_name_or_path: str = "BAAI/bge-m3"
    cache_dir: Optional[str] = None

    # --- Data configuration ---
    train_data_path: str = "./data/train.jsonl"
    eval_data_path: Optional[str] = "./data/eval.jsonl"
    max_seq_length: int = 512
    query_max_length: int = 64
    passage_max_length: int = 512

    # --- Training configuration ---
    output_dir: str = "./output"
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    adam_epsilon: float = 1e-8
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999

    # --- Advanced training configuration ---
    fp16: bool = True
    bf16: bool = False
    gradient_checkpointing: bool = True
    deepspeed: Optional[str] = None
    local_rank: int = -1
    dataloader_num_workers: int = 4
    dataloader_pin_memory: bool = True

    # --- Batch sizes ---
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 8

    # --- Logging and saving ---
    logging_dir: str = "./logs"
    logging_steps: int = 100  # Log every N steps
    save_total_limit: int = 3
    load_best_model_at_end: bool = True
    greater_is_better: bool = False
    evaluation_strategy: str = "steps"

    # --- Checkpoint and resume ---
    overwrite_output_dir: bool = False

    # --- Seed ---
    seed: int = 42

    # --- Platform compatibility ---
    device: str = "auto"  # "auto" is resolved to a concrete device in __post_init__
    n_gpu: int = 0  # Set by _setup_device for the resolved device

    def __post_init__(self):
        """Merge the centralized config into this instance and resolve the device.

        Notes:
            - [hardware] section overrides [training] for overlapping parameters.
            - Model-specific configs ([m3], [reranker]) are applied by subclasses
              after this method and override [training]/[hardware].
            - Values from the centralized config overwrite constructor keyword
              arguments for the fields it covers.
        """
        # Load configuration from centralized config system
        config = self._get_config_safely()
        # Apply configuration values with fallbacks
        self._apply_config_values(config)
        # Setup device detection
        self._setup_device()

    def _get_config_safely(self):
        """Return the centralized config object, or a fallback if it cannot be imported.

        Tries ``utils.config_loader`` (repository-root layout) and then
        ``bge_finetune.utils.config_loader`` (package layout). An ImportError
        at either step moves on to the next option; if both fail, an error is
        logged and ``_create_fallback_config()`` is returned.
        """
        try:
            # Repository-root layout (e.g. scripts run from the project root).
            from utils.config_loader import get_config
            return get_config()
        except ImportError:
            pass
        try:
            # Installed/package layout.
            from bge_finetune.utils.config_loader import get_config  # type: ignore[import]
            return get_config()
        except ImportError:
            pass
        logger.error("Could not import config_loader, using fallback configuration. Please check your environment and config loader path.")
        return self._create_fallback_config()

    def _create_fallback_config(self):
        """Create a minimal config object exposing ``get`` and ``get_section``.

        The returned object mimics the subset of the config-loader interface
        used by ``_apply_config_values``: ``get('section.key', default)`` walks
        dotted keys, and ``get_section(name)`` returns the section dict (or ``{}``).
        """
        class FallbackConfig:
            def __init__(self):
                self._config = {
                    'training': {
                        'default_batch_size': 4,
                        'default_learning_rate': 1e-5,
                        'default_num_epochs': 3,
                        'default_warmup_ratio': 0.1,
                        'default_weight_decay': 0.01,
                        'gradient_accumulation_steps': 4,
                    },
                    'model_paths': {
                        'cache_dir': './cache/data',
                    },
                    'data': {
                        'max_seq_length': 512,
                        'max_query_length': 64,
                        'max_passage_length': 512,
                        'random_seed': 42,
                    },
                    'hardware': {
                        'dataloader_num_workers': 4,
                        'dataloader_pin_memory': True,
                        'per_device_train_batch_size': 4,
                        'per_device_eval_batch_size': 8,
                    },
                    'optimization': {
                        'fp16': True,
                        'bf16': False,
                        'gradient_checkpointing': True,
                        'max_grad_norm': 1.0,
                        'adam_epsilon': 1e-8,
                        'adam_beta1': 0.9,
                        'adam_beta2': 0.999,
                    },
                    'checkpoint': {
                        'save_steps': 1000,
                        'save_total_limit': 3,
                        'load_best_model_at_end': True,
                        'metric_for_best_model': 'eval_loss',
                        'greater_is_better': False,
                    },
                    'logging': {
                        'log_dir': './logs',
                    },
                }

            def get(self, key, default=None):
                """Look up a dotted key such as ``'model_paths.cache_dir'``."""
                value = self._config
                try:
                    for part in key.split('.'):
                        value = value[part]
                    return value
                except (KeyError, TypeError):
                    return default

            def get_section(self, section):
                """Return the named section dict, or ``{}`` if absent."""
                return self._config.get(section, {})

        return FallbackConfig()

    def _copy_section_values(self, section, key_map):
        """Copy each ``config_key`` present in *section* onto attribute ``attr``.

        Args:
            section: mapping for one config section (supports ``in`` and ``[]``).
            key_map: iterable of ``(config_key, attr)`` pairs.
        """
        for config_key, attr in key_map:
            if config_key in section:
                setattr(self, attr, section[config_key])

    def _apply_config_values(self, config):
        """
        Apply configuration values with proper precedence order.

        Precedence (lowest to highest):
            1. Dataclass field defaults (already set)
            2. [training] section values
            3. [hardware] section values (override training)
            4. Model-specific sections (handled in subclasses)
            5. Command-line arguments (handled by training scripts)

        This method handles steps 2-3 plus the [model_paths], [data],
        [optimization], [checkpoint] and [logging] sections. Model-specific
        configs are applied in subclass __post_init__ methods AFTER this base
        method runs. Any exception is logged as a warning and stops further
        application (best-effort).
        """
        try:
            # === STEP 2: Apply [training] section (general defaults) ===
            training_config = config.get_section('training') or {}
            if 'default_num_epochs' in training_config:
                old_value = self.num_train_epochs
                self.num_train_epochs = training_config['default_num_epochs']
                if old_value != self.num_train_epochs:
                    logger.debug(f"[training] num_train_epochs: {old_value} -> {self.num_train_epochs}")
            if 'default_batch_size' in training_config:
                self.per_device_train_batch_size = training_config['default_batch_size']
                logger.debug(f"[training] batch_size set to {self.per_device_train_batch_size}")
            self._copy_section_values(training_config, _TRAINING_KEY_MAP)

            # === STEP 3: Apply [hardware] section (overrides training) ===
            hardware_config = config.get_section('hardware') or {}
            self._copy_section_values(hardware_config, _HARDWARE_KEY_MAP)
            if 'per_device_train_batch_size' in hardware_config:
                old_value = self.per_device_train_batch_size
                self.per_device_train_batch_size = hardware_config['per_device_train_batch_size']
                if old_value != self.per_device_train_batch_size:
                    logger.info(f"[hardware] overrides batch_size: {old_value} -> {self.per_device_train_batch_size}")

            # === Apply other configuration sections ===
            # cache_dir is only filled in when the caller left it unset.
            if self.cache_dir is None:
                cache_dir = config.get('model_paths.cache_dir')
                if cache_dir is not None:
                    self.cache_dir = cache_dir
                    logger.debug(f"cache_dir set from config: {self.cache_dir}")

            for section_name, key_map in (
                ('data', _DATA_KEY_MAP),
                ('optimization', _OPTIMIZATION_KEY_MAP),
                ('checkpoint', _CHECKPOINT_KEY_MAP),
                ('logging', _LOGGING_KEY_MAP),
            ):
                self._copy_section_values(config.get_section(section_name) or {}, key_map)
        except Exception as e:
            logger.warning(f"Error applying configuration values: {e}")

    @staticmethod
    def _npu_device_count() -> Optional[int]:
        """Return the Ascend NPU count, 0 if none is usable, or None if torch_npu is missing."""
        try:
            # Imported for its side effect; presumably it attaches ``torch.npu``.
            import torch_npu  # type: ignore[import]  # noqa: F401
        except ImportError:
            return None
        npu = getattr(torch, 'npu', None)
        if npu is not None and npu.is_available():
            return npu.device_count()
        return 0

    @staticmethod
    def _mps_available() -> bool:
        """Return True if ``torch.backends.mps.is_available()`` exists and reports True."""
        mps = getattr(getattr(torch, 'backends', None), 'mps', None)
        is_available = getattr(mps, 'is_available', None)
        return bool(is_available is not None and is_available())

    def _setup_device(self):
        """Resolve ``self.device`` and set ``self.n_gpu`` for it.

        - ``"auto"``: pick the best available device via ``_detect_available_devices``.
        - ``"cuda"`` / ``"npu"`` / ``"mps"``: use the accelerator if available,
          otherwise log a warning and fall back to ``"cpu"`` with ``n_gpu = 0``.
        - Any other value (e.g. ``"cpu"``): keep it and set ``n_gpu = 0``.
        """
        if self.device == "auto":
            device_info = self._detect_available_devices()
            self.device = device_info['device']
            self.n_gpu = device_info['count']
            logger.info(f"Auto-detected device: {self.device} (count: {self.n_gpu})")
            return

        if self.device == "cuda":
            if torch.cuda.is_available():
                self.n_gpu = torch.cuda.device_count()
                return
            logger.warning("CUDA requested but not available, falling back to CPU")
        elif self.device == "npu":
            npu_count = self._npu_device_count()
            if npu_count:
                self.n_gpu = npu_count
                return
            if npu_count is None:
                logger.warning("NPU requested but torch_npu not installed, falling back to CPU")
            else:
                logger.warning("NPU requested but not available, falling back to CPU")
        elif self.device == "mps":
            if self._mps_available():
                # Same count that _detect_available_devices reports for MPS.
                self.n_gpu = 1
                return
            logger.warning("MPS requested but not available, falling back to CPU")
        else:
            self.n_gpu = 0
            return

        # Requested accelerator is unavailable.
        self.device = "cpu"
        self.n_gpu = 0

    def _detect_available_devices(self) -> Dict[str, Any]:
        """Probe devices and return the best one as ``{'device': str, 'count': int}``.

        Priority: npu > cuda > mps > cpu. CPU is always available and is
        reported with count 1.
        """
        available_devices = []  # appended in priority order
        npu_count = self._npu_device_count()
        if npu_count:
            available_devices.append(('npu', npu_count))
            logger.debug(f"Ascend NPU available: {npu_count} devices")
        elif npu_count is None:
            logger.debug("Ascend NPU library not available")
        if torch.cuda.is_available():
            count = torch.cuda.device_count()
            available_devices.append(('cuda', count))
            logger.debug(f"CUDA available: {count} devices")
        if self._mps_available():
            available_devices.append(('mps', 1))
            logger.debug("Apple MPS available")
        available_devices.append(('cpu', 1))
        device, count = available_devices[0]
        return {'device': device, 'count': count}

    def validate(self):
        """Check paths and numeric settings.

        Returns:
            True when every check passes (and logs that it did).

        Raises:
            ValueError: listing every failed check.
        """
        errors = []
        # Check required paths (eval data only when configured)
        if self.train_data_path and not os.path.exists(self.train_data_path):
            errors.append(f"Training data not found: {self.train_data_path}")
        if self.eval_data_path and not os.path.exists(self.eval_data_path):
            errors.append(f"Evaluation data not found: {self.eval_data_path}")
        # Validate numeric parameters
        if self.learning_rate is None or self.learning_rate <= 0:
            errors.append("Learning rate must be positive and not None")
        if self.num_train_epochs is None or self.num_train_epochs <= 0:
            errors.append("Number of epochs must be positive and not None")
        if self.per_device_train_batch_size is None or self.per_device_train_batch_size <= 0:
            errors.append("Training batch size must be positive and not None")
        if self.gradient_accumulation_steps is None or self.gradient_accumulation_steps <= 0:
            errors.append("Gradient accumulation steps must be positive and not None")
        if self.warmup_ratio is None or self.warmup_ratio < 0 or self.warmup_ratio > 1:
            errors.append("Warmup ratio must be between 0 and 1 and not None")
        # Validate device
        if self.device not in ['auto', 'cpu', 'cuda', 'npu', 'mps']:
            errors.append(f"Invalid device: {self.device}")
        # Check sequence lengths
        if self.max_seq_length is None or self.max_seq_length <= 0:
            errors.append("Maximum sequence length must be positive and not None")
        if self.query_max_length is None or self.query_max_length <= 0:
            errors.append("Maximum query length must be positive and not None")
        if self.passage_max_length is None or self.passage_max_length <= 0:
            errors.append("Maximum passage length must be positive and not None")
        if errors:
            raise ValueError("Configuration validation failed:\n" + "\n".join(f" - {error}" for error in errors))
        logger.info("Configuration validation passed")
        return True

    def to_dict(self) -> Dict[str, Any]:
        """Return instance attributes as JSON-friendly values.

        Scalars and None are kept, lists/tuples become lists, dicts are
        shallow-copied, and anything else is converted with ``str()``.
        """
        config_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, (str, int, float, bool, type(None))):
                config_dict[key] = value
            elif isinstance(value, (list, tuple)):
                config_dict[key] = list(value)
            elif isinstance(value, dict):
                config_dict[key] = dict(value)
            else:
                config_dict[key] = str(value)
        return config_dict

    def save(self, save_path: str):
        """Write ``to_dict()`` as UTF-8 JSON to *save_path*.

        Parent directories are created when *save_path* has a directory
        component; a bare filename is written to the current directory.
        Errors are logged and re-raised.
        """
        import json
        try:
            directory = os.path.dirname(save_path)
            # os.makedirs('') raises, so skip it for bare filenames.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
            logger.info(f"Configuration saved to {save_path}")
        except Exception as e:
            logger.error(f"Failed to save configuration: {e}")
            raise

    @classmethod
    def _build_from_dict(cls, config_dict: Dict[str, Any]):
        """Instantiate ``cls`` so that values in *config_dict* win over the centralized config.

        Unknown keys (not ``__init__`` parameters) are dropped. After
        construction -- during which ``__post_init__`` merges the centralized
        config over the constructor arguments -- the caller's values are
        re-applied, except the machine-resolved ``device``/``n_gpu``. The
        result is validated before being returned.
        """
        import inspect
        valid_keys = set(inspect.signature(cls).parameters)
        filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys}
        config = cls(**filtered_dict)
        for key, value in filtered_dict.items():
            if key not in _MACHINE_RESOLVED_FIELDS:
                setattr(config, key, value)
        config.validate()
        return config

    @classmethod
    def load(cls, load_path: str):
        """Load a configuration previously written by ``save``.

        Raises:
            Whatever reading, JSON parsing, construction or ``validate`` raise
            (logged first).
        """
        import json
        try:
            with open(load_path, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)
            config = cls._build_from_dict(config_dict)
            logger.info(f"Configuration loaded from {load_path}")
            return config
        except Exception as e:
            logger.error(f"Failed to load configuration from {load_path}: {e}")
            raise

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]):
        """Create and validate a config whose values come from *config_dict*.

        Unknown keys are ignored; errors are logged and re-raised.
        """
        try:
            return cls._build_from_dict(config_dict)
        except Exception as e:
            logger.error(f"Failed to create configuration from dictionary: {e}")
            raise

    def update_from_args(self, args):
        """Override attributes from an argparse-style namespace, then validate.

        Every non-None value whose name matches an existing non-callable
        attribute is applied. Methods are never overwritten (e.g. an argument
        named ``save`` would otherwise replace ``save()``).

        NOTE(review): argparse ``store_true`` flags default to False rather than
        None, so such flags always override config values here -- confirm this
        is intended by the calling scripts.
        """
        try:
            updated_fields = []
            for key, value in vars(args).items():
                if value is None or not hasattr(self, key):
                    continue
                if callable(getattr(self, key)):
                    continue
                old_value = getattr(self, key)
                setattr(self, key, value)
                updated_fields.append(f"{key}: {old_value} -> {value}")
            if updated_fields:
                logger.info("Updated configuration from command line arguments:")
                for update in updated_fields:
                    logger.info(f" {update}")
            # Validate after updates
            self.validate()
        except Exception as e:
            logger.error(f"Failed to update configuration from arguments: {e}")
            raise

    def get_effective_batch_size(self) -> int:
        """Return per-device batch size x gradient accumulation x max(1, n_gpu)."""
        return self.per_device_train_batch_size * self.gradient_accumulation_steps * max(1, self.n_gpu)

    def get_training_steps(self, dataset_size: int) -> int:
        """Return the total number of optimizer steps for ``num_train_epochs`` epochs.

        Steps per epoch is ``dataset_size // effective_batch_size``, but at
        least 1 for a non-empty dataset so that a dataset smaller than one
        effective batch does not report zero training steps.
        """
        steps_per_epoch = dataset_size // self.get_effective_batch_size()
        if dataset_size > 0:
            steps_per_epoch = max(1, steps_per_epoch)
        return steps_per_epoch * self.num_train_epochs

    def print_effective_config(self):
        """Print each dataclass field and its current value as a table."""
        print("Parameter | Value | Source")
        print("----------|-------|-------")
        for field_name in self.__dataclass_fields__:
            value = getattr(self, field_name)
            # Source tracking is not implemented; see the class docstring for precedence.
            print(f"{field_name} | {value} | (see docstring for precedence)")

    def __str__(self) -> str:
        """Short summary: class name, device, train batch size and learning rate."""
        return f"{type(self).__name__}(device={self.device}, batch_size={self.per_device_train_batch_size}, lr={self.learning_rate})"

    def __repr__(self) -> str:
        return self.__str__()

    def get_device(self):
        """Return the device string for training.

        Returns the configured ``self.device`` when it names a concrete device
        (``__post_init__`` always resolves "auto"). If it is "auto" or empty
        -- e.g. reset later via ``update_from_args`` -- the best available
        device is probed with the same priority as auto-detection.
        """
        if self.device and self.device != "auto":
            return self.device
        return self._detect_available_devices()['device']

    def is_npu(self):
        """Return True if an NPU is available and is the selected device."""
        npu = getattr(torch, "npu", None)
        if npu is None or not hasattr(npu, "is_available") or not npu.is_available():
            return False
        return self.get_device() == "npu"