import os import torch from dataclasses import dataclass, field from typing import Optional, List, Dict, Any import logging logger = logging.getLogger(__name__) @dataclass class BaseConfig: """ Base configuration for all trainers and models. Configuration Precedence Order: 1. Command-line arguments (highest priority) 2. Model-specific config sections ([m3], [reranker]) 3. [hardware] section (for hardware-specific overrides) 4. [training] section (general training defaults) 5. Dataclass field defaults (lowest priority) Example: If batch_size is defined in multiple places: - CLI arg --batch_size 32 (wins) - [m3] batch_size = 16 - [hardware] per_device_train_batch_size = 8 - [training] default_batch_size = 4 - Dataclass default = 2 Note: Model-specific configs ([m3], [reranker]) take precedence over [hardware] for their respective models to allow fine-grained control. """ temperature: float = 1.0 # Softmax temperature for loss functions use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss negatives_cross_device: bool = False # Use negatives across devices (DDP) label_smoothing: float = 0.0 # Label smoothing for classification losses alpha: float = 1.0 # Weight for auxiliary losses gradient_accumulation_steps: int = 1 # Steps to accumulate gradients num_train_epochs: int = 3 # Number of training epochs eval_steps: int = 1 # Evaluate every N epochs save_steps: int = 1 # Save checkpoint every N epochs warmup_steps: int = 0 # Number of warmup steps for scheduler warmup_ratio: float = 0.0 # Warmup ratio for scheduler metric_for_best_model: str = "accuracy" # Metric to select best model checkpoint_dir: str = "./checkpoints" # Directory for saving checkpoints resume_from_checkpoint: str = "" # Path to resume checkpoint max_grad_norm: float = 1.0 # Max gradient norm for clipping use_amp: bool = False # Use automatic mixed precision # Model configuration model_name_or_path: str = "BAAI/bge-m3" cache_dir: Optional[str] = None # Data configuration train_data_path: str = "./data/train.jsonl" eval_data_path: Optional[str] = "./data/eval.jsonl" max_seq_length: int = 512 query_max_length: int = 64 passage_max_length: int = 512 # Training configuration output_dir: str = "./output" learning_rate: float = 1e-5 weight_decay: float = 0.01 adam_epsilon: float = 1e-8 adam_beta1: float = 0.9 adam_beta2: float = 0.999 # Advanced training configuration fp16: bool = True bf16: bool = False gradient_checkpointing: bool = True deepspeed: Optional[str] = None local_rank: int = -1 dataloader_num_workers: int = 4 dataloader_pin_memory: bool = True # Batch sizes per_device_train_batch_size: int = 4 per_device_eval_batch_size: int = 8 # Logging and saving logging_dir: str = "./logs" # For testing only, in production, set to 100 logging_steps: int = 100 save_total_limit: int = 3 load_best_model_at_end: bool = True greater_is_better: bool = False evaluation_strategy: str = "steps" # Checkpoint and resume overwrite_output_dir: bool = False # Seed seed: int = 42 # Platform compatibility device: str = "auto" n_gpu: int = 0 def __post_init__(self): """Post-initialization with robust config loading and device setup Notes: - [hardware] section overrides [training] for overlapping parameters. - Model-specific configs ([m3], [reranker]) override [training]/[hardware] for their models. """ # Load configuration from centralized config system config = self._get_config_safely() # Apply configuration values with fallbacks self._apply_config_values(config) # Setup device detection self._setup_device() def _get_config_safely(self): """Safely get configuration with multiple fallback strategies""" try: # Try absolute import from utils.config_loader import get_config return get_config() except ImportError: try: # Try absolute import (for scripts) from bge_finetune.utils.config_loader import get_config # type: ignore[import] return get_config() except ImportError: try: # Try direct import (when run as script from root) from utils.config_loader import get_config return get_config() except ImportError: logger.error("Could not import config_loader, using fallback configuration. Please check your environment and config loader path.") return self._create_fallback_config() def _create_fallback_config(self): """Create a minimal fallback configuration object""" class FallbackConfig: def __init__(self): self._config = { 'training': { 'default_batch_size': 4, 'default_learning_rate': 1e-5, 'default_num_epochs': 3, 'default_warmup_ratio': 0.1, 'default_weight_decay': 0.01, 'gradient_accumulation_steps': 4 }, 'model_paths': { 'cache_dir': './cache/data' }, 'data': { 'max_seq_length': 512, 'max_query_length': 64, 'max_passage_length': 512, 'random_seed': 42 }, 'hardware': { 'dataloader_num_workers': 4, 'dataloader_pin_memory': True, 'per_device_train_batch_size': 4, 'per_device_eval_batch_size': 8 }, 'optimization': { 'fp16': True, 'bf16': False, 'gradient_checkpointing': True, 'max_grad_norm': 1.0, 'adam_epsilon': 1e-8, 'adam_beta1': 0.9, 'adam_beta2': 0.999 }, 'checkpoint': { 'save_steps': 1000, 'save_total_limit': 3, 'load_best_model_at_end': True, 'metric_for_best_model': 'eval_loss', 'greater_is_better': False }, 'logging': { 'log_dir': './logs' } } def get(self, key, default=None): keys = key.split('.') value = self._config try: for k in keys: value = value[k] return value except (KeyError, TypeError): return default def get_section(self, section): return self._config.get(section, {}) return FallbackConfig() def _apply_config_values(self, config): """ Apply configuration values with proper precedence order. Precedence (lowest to highest): 1. Dataclass field defaults (already set) 2. [training] section values 3. [hardware] section values (override training) 4. Model-specific sections (handled in subclasses) 5. Command-line arguments (handled by training scripts) This method handles steps 2-3. Model-specific configs are applied in subclass __post_init__ methods AFTER this base method runs. """ try: # === STEP 2: Apply [training] section (general defaults) === training_config = config.get_section('training') if training_config: # Only update if value exists in config if 'default_num_epochs' in training_config: old_value = self.num_train_epochs self.num_train_epochs = training_config['default_num_epochs'] if old_value != self.num_train_epochs: logger.debug(f"[training] num_train_epochs: {old_value} -> {self.num_train_epochs}") if 'default_batch_size' in training_config: self.per_device_train_batch_size = training_config['default_batch_size'] logger.debug(f"[training] batch_size set to {self.per_device_train_batch_size}") if 'default_learning_rate' in training_config: self.learning_rate = training_config['default_learning_rate'] if 'default_warmup_ratio' in training_config: self.warmup_ratio = training_config['default_warmup_ratio'] if 'gradient_accumulation_steps' in training_config: self.gradient_accumulation_steps = training_config['gradient_accumulation_steps'] if 'default_weight_decay' in training_config: self.weight_decay = training_config['default_weight_decay'] # === STEP 3: Apply [hardware] section (overrides training) === hardware_config = config.get_section('hardware') if hardware_config: # Hardware settings take precedence over training defaults if 'dataloader_num_workers' in hardware_config: self.dataloader_num_workers = hardware_config['dataloader_num_workers'] if 'dataloader_pin_memory' in hardware_config: self.dataloader_pin_memory = hardware_config['dataloader_pin_memory'] # Batch sizes from hardware override training defaults if 'per_device_train_batch_size' in hardware_config: old_value = self.per_device_train_batch_size self.per_device_train_batch_size = hardware_config['per_device_train_batch_size'] if old_value != self.per_device_train_batch_size: logger.info(f"[hardware] overrides batch_size: {old_value} -> {self.per_device_train_batch_size}") if 'per_device_eval_batch_size' in hardware_config: self.per_device_eval_batch_size = hardware_config['per_device_eval_batch_size'] # === Apply other configuration sections === # Model paths if self.cache_dir is None: cache_dir = config.get('model_paths.cache_dir') if cache_dir is not None: self.cache_dir = cache_dir logger.debug(f"cache_dir set from config: {self.cache_dir}") # Data configuration data_config = config.get_section('data') if data_config: if 'max_seq_length' in data_config: self.max_seq_length = data_config['max_seq_length'] if 'max_query_length' in data_config: self.query_max_length = data_config['max_query_length'] if 'max_passage_length' in data_config: self.passage_max_length = data_config['max_passage_length'] if 'random_seed' in data_config: self.seed = data_config['random_seed'] # Optimization configuration opt_config = config.get_section('optimization') if opt_config: if 'fp16' in opt_config: self.fp16 = opt_config['fp16'] if 'bf16' in opt_config: self.bf16 = opt_config['bf16'] if 'gradient_checkpointing' in opt_config: self.gradient_checkpointing = opt_config['gradient_checkpointing'] if 'max_grad_norm' in opt_config: self.max_grad_norm = opt_config['max_grad_norm'] if 'adam_epsilon' in opt_config: self.adam_epsilon = opt_config['adam_epsilon'] if 'adam_beta1' in opt_config: self.adam_beta1 = opt_config['adam_beta1'] if 'adam_beta2' in opt_config: self.adam_beta2 = opt_config['adam_beta2'] # Checkpoint configuration checkpoint_config = config.get_section('checkpoint') if checkpoint_config: if 'save_steps' in checkpoint_config: self.save_steps = checkpoint_config['save_steps'] if 'save_total_limit' in checkpoint_config: self.save_total_limit = checkpoint_config['save_total_limit'] if 'load_best_model_at_end' in checkpoint_config: self.load_best_model_at_end = checkpoint_config['load_best_model_at_end'] if 'metric_for_best_model' in checkpoint_config: self.metric_for_best_model = checkpoint_config['metric_for_best_model'] if 'greater_is_better' in checkpoint_config: self.greater_is_better = checkpoint_config['greater_is_better'] # Logging configuration logging_config = config.get_section('logging') if logging_config: if 'log_dir' in logging_config: self.logging_dir = logging_config['log_dir'] except Exception as e: logger.warning(f"Error applying configuration values: {e}") def _setup_device(self): """Enhanced device detection with Ascend NPU support""" if self.device == "auto": device_info = self._detect_available_devices() self.device = device_info['device'] self.n_gpu = device_info['count'] logger.info(f"Auto-detected device: {self.device} (count: {self.n_gpu})") else: # Manual device specification if self.device == "cuda" and torch.cuda.is_available(): self.n_gpu = torch.cuda.device_count() elif self.device == "npu": try: import torch_npu # type: ignore[import] npu = getattr(torch, 'npu', None) if npu is not None and npu.is_available(): self.n_gpu = npu.device_count() else: logger.warning("NPU requested but not available, falling back to CPU") self.device = "cpu" self.n_gpu = 0 except ImportError: logger.warning("NPU requested but torch_npu not installed, falling back to CPU") self.device = "cpu" self.n_gpu = 0 else: self.n_gpu = 0 def _detect_available_devices(self) -> Dict[str, Any]: """Comprehensive device detection with priority ordering""" device_priority = ['npu', 'cuda', 'mps', 'cpu'] available_devices = [] # Check Ascend NPU try: import torch_npu # type: ignore[import] npu = getattr(torch, 'npu', None) if npu is not None and npu.is_available(): count = npu.device_count() available_devices.append(('npu', count)) logger.debug(f"Ascend NPU available: {count} devices") except ImportError: logger.debug("Ascend NPU library not available") # Check CUDA if torch.cuda.is_available(): count = torch.cuda.device_count() available_devices.append(('cuda', count)) logger.debug(f"CUDA available: {count} devices") # Check MPS (Apple Silicon) if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and hasattr(torch.backends.mps, 'is_available') and torch.backends.mps.is_available(): # type: ignore[attr-defined] available_devices.append(('mps', 1)) logger.debug("Apple MPS available") # CPU always available available_devices.append(('cpu', 1)) # Select best available device based on priority for device_type in device_priority: for available_type, count in available_devices: if available_type == device_type: return {'device': device_type, 'count': count} return {'device': 'cpu', 'count': 1} # Fallback def validate(self): """Comprehensive configuration validation""" errors = [] # Check required paths if self.train_data_path and not os.path.exists(self.train_data_path): errors.append(f"Training data not found: {self.train_data_path}") if self.eval_data_path and not os.path.exists(self.eval_data_path): errors.append(f"Evaluation data not found: {self.eval_data_path}") # Validate numeric parameters if self.learning_rate is None or self.learning_rate <= 0: errors.append("Learning rate must be positive and not None") if self.num_train_epochs is None or self.num_train_epochs <= 0: errors.append("Number of epochs must be positive and not None") if self.per_device_train_batch_size is None or self.per_device_train_batch_size <= 0: errors.append("Training batch size must be positive and not None") if self.gradient_accumulation_steps is None or self.gradient_accumulation_steps <= 0: errors.append("Gradient accumulation steps must be positive and not None") if self.warmup_ratio is None or self.warmup_ratio < 0 or self.warmup_ratio > 1: errors.append("Warmup ratio must be between 0 and 1 and not None") # Validate device if self.device not in ['auto', 'cpu', 'cuda', 'npu', 'mps']: errors.append(f"Invalid device: {self.device}") # Check sequence lengths if self.max_seq_length is None or self.max_seq_length <= 0: errors.append("Maximum sequence length must be positive and not None") if self.query_max_length is None or self.query_max_length <= 0: errors.append("Maximum query length must be positive and not None") if self.passage_max_length is None or self.passage_max_length <= 0: errors.append("Maximum passage length must be positive and not None") if errors: raise ValueError(f"Configuration validation failed:\n" + "\n".join(f" - {error}" for error in errors)) logger.info("Configuration validation passed") return True def to_dict(self) -> Dict[str, Any]: """Convert config to dictionary with type handling""" config_dict = {} for key, value in self.__dict__.items(): if isinstance(value, (str, int, float, bool, type(None))): config_dict[key] = value elif isinstance(value, (list, tuple)): config_dict[key] = list(value) elif isinstance(value, dict): config_dict[key] = dict(value) else: config_dict[key] = str(value) return config_dict def save(self, save_path: str): """Save configuration to file with error handling""" try: os.makedirs(os.path.dirname(save_path), exist_ok=True) import json with open(save_path, 'w', encoding='utf-8') as f: json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) logger.info(f"Configuration saved to {save_path}") except Exception as e: logger.error(f"Failed to save configuration: {e}") raise @classmethod def load(cls, load_path: str): """Load configuration from file with validation""" try: import json with open(load_path, 'r', encoding='utf-8') as f: config_dict = json.load(f) # Filter only valid field names import inspect sig = inspect.signature(cls) valid_keys = set(sig.parameters.keys()) filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys} config = cls(**filtered_dict) config.validate() # Validate after loading logger.info(f"Configuration loaded from {load_path}") return config except Exception as e: logger.error(f"Failed to load configuration from {load_path}: {e}") raise @classmethod def from_dict(cls, config_dict: Dict[str, Any]): """Create config from dictionary with validation""" try: # Filter only valid field names import inspect sig = inspect.signature(cls) valid_keys = set(sig.parameters.keys()) filtered_dict = {k: v for k, v in config_dict.items() if k in valid_keys} config = cls(**filtered_dict) config.validate() # Validate after creation return config except Exception as e: logger.error(f"Failed to create configuration from dictionary: {e}") raise def update_from_args(self, args): """Update config from command line arguments with validation""" try: updated_fields = [] for key, value in vars(args).items(): if hasattr(self, key) and value is not None: old_value = getattr(self, key) setattr(self, key, value) updated_fields.append(f"{key}: {old_value} -> {value}") if updated_fields: logger.info("Updated configuration from command line arguments:") for field in updated_fields: logger.info(f" {field}") # Validate after updates self.validate() except Exception as e: logger.error(f"Failed to update configuration from arguments: {e}") raise def get_effective_batch_size(self) -> int: """Calculate effective batch size considering gradient accumulation""" return self.per_device_train_batch_size * self.gradient_accumulation_steps * max(1, self.n_gpu) def get_training_steps(self, dataset_size: int) -> int: """Calculate total training steps""" steps_per_epoch = dataset_size // self.get_effective_batch_size() return steps_per_epoch * self.num_train_epochs def print_effective_config(self): """Print the effective value and its source for each parameter.""" print("Parameter | Value | Source") print("----------|-------|-------") for field in self.__dataclass_fields__: value = getattr(self, field) # For now, just print the value; source tracking can be added if needed print(f"{field} | {value} | (see docstring for precedence)") def __str__(self) -> str: """String representation of configuration""" return f"BaseConfig(device={self.device}, batch_size={self.per_device_train_batch_size}, lr={self.learning_rate})" def __repr__(self) -> str: return self.__str__() def get_device(self): """Return the device string for training (cuda, npu, mps, cpu)""" if torch.cuda.is_available(): return "cuda" npu = getattr(torch, "npu", None) if npu is not None and hasattr(npu, "is_available") and npu.is_available(): return "npu" backends = getattr(torch, "backends", None) if backends is not None and hasattr(backends, "mps") and backends.mps.is_available(): return "mps" return "cpu" def is_npu(self): """Return True if NPU is available and selected as device.""" npu = getattr(torch, "npu", None) if npu is not None and hasattr(npu, "is_available") and npu.is_available(): device = self.get_device() if hasattr(self, 'get_device') else "cpu" return device == "npu" return False