""" Centralized configuration loader for BGE fine-tuning project """ import os import toml import logging from typing import Dict, Any, Optional, Tuple from pathlib import Path logger = logging.getLogger(__name__) def resolve_model_path( model_key: str, config_instance: 'ConfigLoader', model_name_or_path: Optional[str] = None, cache_dir: Optional[str] = None ) -> Tuple[str, str]: """ Resolve model path with automatic local/remote fallback Args: model_key: Key in model_paths config (e.g., 'bge_m3', 'bge_reranker') config_instance: ConfigLoader instance model_name_or_path: Optional override from CLI cache_dir: Optional cache directory Returns: Tuple of (resolved_model_path, resolved_cache_dir) """ # Default remote model names for each model type remote_models = { 'bge_m3': 'BAAI/bge-m3', 'bge_reranker': 'BAAI/bge-reranker-base', } # Get model source (huggingface or modelscope) model_source = config_instance.get('model_paths', 'source', 'huggingface') # If ModelScope is used, adjust model names if model_source == 'modelscope': remote_models = { 'bge_m3': 'damo/nlp_gte_sentence-embedding_chinese-base', 'bge_reranker': 'damo/nlp_corom_passage-ranking_english-base', } # 1. Check CLI override first (only if explicitly provided) if model_name_or_path is not None: resolved_model_path = model_name_or_path logger.info(f"[CLI OVERRIDE] Using model path: {resolved_model_path}") else: # 2. Check local path from config local_path_raw = config_instance.get('model_paths', model_key, None) # Handle list/string conversion if isinstance(local_path_raw, list): local_path = local_path_raw[0] if local_path_raw else None else: local_path = local_path_raw if local_path and isinstance(local_path, str) and os.path.exists(local_path): # Check if it's a valid model directory if _is_valid_model_directory(local_path): resolved_model_path = local_path logger.info(f"[LOCAL] Using local model: {resolved_model_path}") else: logger.warning(f"[LOCAL] Invalid model directory: {local_path}, falling back to remote") resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}') else: # 3. 
def _is_valid_model_directory(path: str) -> bool:
    """Check whether a directory contains a valid model."""
    required_files = ['config.json']
    optional_files = ['tokenizer_config.json', 'vocab.txt', 'tokenizer.json']
    model_files = ['pytorch_model.bin', 'model.safetensors', 'pytorch_model.bin.index.json']

    path_obj = Path(path)
    if not path_obj.exists() or not path_obj.is_dir():
        return False

    # Check for required files
    for file in required_files:
        if not (path_obj / file).exists():
            return False

    # Check for at least one model weights file
    has_model_file = any((path_obj / file).exists() for file in model_files)
    if not has_model_file:
        return False

    # Check for at least one tokenizer file (for embedding models)
    has_tokenizer_file = any((path_obj / file).exists() for file in optional_files)
    if not has_tokenizer_file:
        logger.warning(f"Model directory {path} missing tokenizer files, but proceeding anyway")

    return True


def download_model_if_needed(
    model_name_or_path: str,
    local_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    model_source: str = 'huggingface'
) -> str:
    """
    Download a model if it is a remote identifier and a local path is specified.

    Args:
        model_name_or_path: Model identifier or path
        local_path: Target local path for the download
        cache_dir: Cache directory
        model_source: 'huggingface' or 'modelscope'

    Returns:
        Final model path (local if downloaded, original if already local)
    """
    # If it is already a local path, return it as is
    if os.path.exists(model_name_or_path):
        return model_name_or_path

    # If no local path is specified, return the original (transformers will handle it)
    if not local_path:
        return model_name_or_path

    # Check whether the local path already exists and is valid
    if os.path.exists(local_path) and _is_valid_model_directory(local_path):
        logger.info(f"Model already exists locally: {local_path}")
        return local_path

    # Download the model
    try:
        logger.info(f"Downloading model {model_name_or_path} to {local_path}")

        if model_source == 'huggingface':
            from transformers import AutoModel, AutoTokenizer

            # Download the model and tokenizer
            model = AutoModel.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            # Save to the local path
            os.makedirs(local_path, exist_ok=True)
            model.save_pretrained(local_path)
            tokenizer.save_pretrained(local_path)

            logger.info(f"Model downloaded successfully to {local_path}")
            return local_path

        elif model_source == 'modelscope':
            try:
                from modelscope import snapshot_download  # type: ignore

                # Download using ModelScope
                downloaded_path = snapshot_download(
                    model_name_or_path,
                    cache_dir=cache_dir,
                    local_dir=local_path
                )
                logger.info(f"Model downloaded successfully from ModelScope to {downloaded_path}")
                return downloaded_path
            except ImportError:
                logger.warning("ModelScope not available, falling back to HuggingFace")
                return download_model_if_needed(
                    model_name_or_path, local_path, cache_dir, 'huggingface'
                )

    except Exception as e:
        logger.error(f"Failed to download model {model_name_or_path}: {e}")
        logger.info("Returning original model path - transformers will handle caching")

    # Fallback return
    return model_name_or_path
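
# A directory passes _is_valid_model_directory() when it looks like a standard
# HuggingFace export, e.g. (illustrative layout):
#
#   ./models/bge-m3/
#       config.json              # required
#       model.safetensors        # at least one weights file required
#       tokenizer_config.json    # tokenizer files recommended; warned if absent
#       tokenizer.json
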
class ConfigLoader:
    """Centralized configuration management for BGE fine-tuning.

    Parameter precedence (highest to lowest):
        1. Command-line arguments (CLI)
        2. Model-specific config ([m3], [reranker])
        3. [hardware] section in config.toml
        4. [training] section in config.toml
        5. Hardcoded defaults in code
    """

    _instance = None
    _config = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if self._config is None:
            self._load_config()

    def _find_config_file(self) -> str:
        """Find the config.toml file in the project hierarchy."""
        # Check the environment variable FIRST (highest priority)
        if 'BGE_CONFIG_PATH' in os.environ:
            config_path = Path(os.environ['BGE_CONFIG_PATH'])
            if config_path.exists():
                return str(config_path)
            else:
                logger.warning(
                    f"BGE_CONFIG_PATH set to {config_path} but file doesn't exist, "
                    f"falling back to default"
                )

        current_dir = Path(__file__).parent
        # Search up the directory tree for config.toml
        for _ in range(5):  # Limit search depth
            config_path = current_dir / "config.toml"
            if config_path.exists():
                return str(config_path)
            current_dir = current_dir.parent

        # Default location
        return os.path.join(os.getcwd(), "config.toml")

    def _load_config(self):
        """Load configuration from the TOML file."""
        config_path = self._find_config_file()
        try:
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = toml.load(f)
                if loaded is None:
                    loaded = {}
                self._config = dict(loaded)
                logger.info(f"Loaded configuration from {config_path}")
            else:
                logger.warning(f"Config file not found at {config_path}, using defaults")
                self._config = self._get_default_config()
                self._save_default_config(config_path)
                logger.warning("Default config created. Please check your configuration file.")
        except Exception as e:
            logger.error(f"Error loading config: {e}")
            self._config = self._get_default_config()

        # Validate for unset/placeholder values
        self._validate_config()

    def _validate_config(self):
        """Validate the config for unset/placeholder values."""
        api_key = self.get('api', 'api_key', None)
        if not api_key or api_key == 'your-api-key-here':
            logger.warning(
                "API key is unset or a placeholder. Set it in config.toml or via "
                "the BGE_API_KEY environment variable."
            )
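
    # Example (hypothetical): pointing the loader at a non-default config file.
    # The path below is illustrative.
    #
    #   os.environ['BGE_CONFIG_PATH'] = '/etc/bge/config.toml'
    #   cfg = get_config()
    #   cfg.reload()  # honors BGE_CONFIG_PATH before the directory-tree search
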
    def _get_default_config(self) -> Dict[str, Any]:
        """Get the default configuration."""
        return {
            'api': {
                'base_url': 'https://api.deepseek.com/v1',
                'api_key': 'your-api-key-here',
                'model': 'deepseek-chat',
                'max_retries': 3,
                'timeout': 30,
                'temperature': 0.7
            },
            'training': {
                'default_batch_size': 4,
                'default_learning_rate': 1e-5,
                'default_num_epochs': 3,
                'default_warmup_ratio': 0.1,
                'default_weight_decay': 0.01,
                'gradient_accumulation_steps': 4
            },
            'model_paths': {
                'bge_m3': './models/bge-m3',
                'bge_reranker': './models/bge-reranker-base',
                'cache_dir': './cache/data'
            },
            'data': {
                'max_query_length': 64,
                'max_passage_length': 512,
                'max_seq_length': 512,
                'validation_split': 0.1,
                'random_seed': 42
            },
            'augmentation': {
                'enable_query_augmentation': False,
                'enable_passage_augmentation': False,
                'augmentation_methods': ['paraphrase', 'back_translation'],
                'num_augmentations_per_sample': 2,
                'augmentation_temperature': 0.8
            },
            'hard_negative_mining': {
                'enable_mining': True,
                'num_hard_negatives': 15,
                'negative_range_min': 10,
                'negative_range_max': 100,
                'margin': 0.1,
                'index_refresh_interval': 1000,
                'index_type': 'IVF',
                'nlist': 100
            },
            'evaluation': {
                'eval_batch_size': 32,
                'k_values': [1, 3, 5, 10, 20, 50, 100],
                'metrics': ['recall', 'precision', 'map', 'mrr', 'ndcg'],
                'save_predictions': True,
                'predictions_dir': './output/predictions'
            },
            'logging': {
                'log_level': 'INFO',
                'log_to_file': True,
                'log_dir': './logs',
                'tensorboard_dir': './logs/tensorboard',
                'wandb_project': 'bge-finetuning',
                'wandb_entity': '',
                'use_wandb': False
            },
            'distributed': {
                'backend': 'nccl',
                'init_method': 'env://',
                'find_unused_parameters': True
            },
            'optimization': {
                'gradient_checkpointing': True,
                'fp16': True,
                'bf16': False,
                'fp16_opt_level': 'O1',
                'max_grad_norm': 1.0,
                'adam_epsilon': 1e-8,
                'adam_beta1': 0.9,
                'adam_beta2': 0.999
            },
            'hardware': {
                'cuda_visible_devices': '0,1,2,3',
                'dataloader_num_workers': 4,
                'dataloader_pin_memory': True,
                'per_device_train_batch_size': 4,
                'per_device_eval_batch_size': 8
            },
            'checkpoint': {
                'save_strategy': 'steps',
                'save_steps': 1000,
                'save_total_limit': 3,
                'load_best_model_at_end': True,
                'metric_for_best_model': 'eval_recall@10',
                'greater_is_better': True,
                'resume_from_checkpoint': ''
            },
            'export': {
                'export_format': ['pytorch', 'onnx'],
                'optimize_for_inference': True,
                'quantize': False,
                'quantization_bits': 8
            }
        }

    def _save_default_config(self, config_path: str):
        """Save the default configuration to a file."""
        try:
            # Guard against an empty dirname (e.g., a bare relative filename)
            config_dir = os.path.dirname(config_path)
            if config_dir:
                os.makedirs(config_dir, exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                config_to_save = self._config if isinstance(self._config, dict) else {}
                toml.dump(config_to_save, f)
            logger.info(f"Created default config at {config_path}")
        except Exception as e:
            logger.error(f"Error saving default config: {e}")

    def get(self, section: str, key: str, default=None):
        """Get a config value, logging its source (config.toml or default)."""
        # CLI overrides are handled outside this loader; here we check config and default
        value = default
        source = 'default'
        if isinstance(self._config, dict) and section in self._config and key in self._config[section]:
            value = self._config[section][key]
            source = 'config.toml'

        # Parse comma/semicolon-separated strings into lists
        if isinstance(value, str) and (',' in value or ';' in value):
            try:
                value = [v.strip() for v in value.replace(';', ',').split(',') if v.strip()]
                source += ' (parsed as list)'
            except Exception as e:
                logger.warning(f"Failed to parse list for {section}.{key}: {e}")

        log_func = logger.info if source != 'default' else logger.warning
        log_func(f"Config param {section}.{key} = {value} (source: {source})")
        return value
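
    # Example (hypothetical values): `get` falls back to the default and parses
    # delimited strings into lists.
    #
    #   cfg = get_config()
    #   cfg.get('training', 'default_batch_size', 4)   # -> 4 (or config value)
    #   # A config value of "paraphrase, back_translation" is returned as
    #   # ['paraphrase', 'back_translation'], source 'config.toml (parsed as list)'.
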
    def get_section(self, section: str) -> Dict[str, Any]:
        """Get an entire configuration section."""
        if isinstance(self._config, dict):
            return self._config.get(section, {})
        return {}

    def update(self, key: str, value: Any):
        """Update a configuration value. Logs a warning when overriding an existing value."""
        if self._config is None:
            self._config = {}

        keys = key.split('.')
        config = self._config

        # Navigate to the parent of the final key
        for k in keys[:-1]:
            if k not in config:
                config[k] = {}
            config = config[k]

        # Log if overriding
        old_value = config.get(keys[-1], None)
        if old_value is not None and old_value != value:
            logger.warning(
                f"Parameter '{key}' overridden: {old_value} -> {value} (higher-precedence source)"
            )

        config[keys[-1]] = value

    def reload(self):
        """Reload the configuration from file."""
        self._config = None
        self._load_config()


# Global instance
config = ConfigLoader()


def get_config() -> ConfigLoader:
    """Get the global configuration instance."""
    return config


def load_config_for_training(
    config_path: Optional[str] = None,
    override_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Load configuration specifically for training scripts.

    Args:
        config_path: Optional path to a config file
        override_params: Parameters to override from the command line

    Returns:
        Dictionary with all configuration parameters
    """
    global config

    # Reload with the specific path if provided
    if config_path:
        old_path = os.environ.get('BGE_CONFIG_PATH')
        os.environ['BGE_CONFIG_PATH'] = config_path
        config.reload()
        if old_path:
            os.environ['BGE_CONFIG_PATH'] = old_path
        else:
            del os.environ['BGE_CONFIG_PATH']

    # Get all configuration sections
    training_config = {
        'api': config.get_section('api'),
        'training': config.get_section('training'),
        'model_paths': config.get_section('model_paths'),
        'data': config.get_section('data'),
        'augmentation': config.get_section('augmentation'),
        'hard_negative_mining': config.get_section('hard_negative_mining'),
        'evaluation': config.get_section('evaluation'),
        'logging': config.get_section('logging'),
        'distributed': config.get_section('distributed'),
        'optimization': config.get_section('optimization'),
        'hardware': config.get_section('hardware'),
        'checkpoint': config.get_section('checkpoint'),
        'export': config.get_section('export'),
        'm3': config.get_section('m3'),
        'reranker': config.get_section('reranker'),
    }

    # Apply overrides if provided
    if override_params:
        for key, value in override_params.items():
            config.update(key, value)
            # Also update the returned dict
            keys = key.split('.')
            current = training_config
            for k in keys[:-1]:
                if k not in current:
                    current[k] = {}
                current = current[k]
            current[keys[-1]] = value

    return training_config
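

# Minimal smoke test, runnable only when this module is executed directly.
# The override key below is hypothetical; any dotted 'section.key' path works.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    cfg = load_config_for_training(
        override_params={'training.default_batch_size': 8}
    )
    print(cfg['training'].get('default_batch_size'))  # -> 8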