bge_finetune/utils/config_loader.py
"""
Centralized configuration loader for BGE fine-tuning project
"""
import os
import toml
import logging
from typing import Dict, Any, Optional, Tuple
from pathlib import Path
logger = logging.getLogger(__name__)

def resolve_model_path(
    model_key: str,
    config_instance: 'ConfigLoader',
    model_name_or_path: Optional[str] = None,
    cache_dir: Optional[str] = None
) -> Tuple[str, str]:
    """
    Resolve model path with automatic local/remote fallback

    Args:
        model_key: Key in model_paths config (e.g., 'bge_m3', 'bge_reranker')
        config_instance: ConfigLoader instance
        model_name_or_path: Optional override from CLI
        cache_dir: Optional cache directory

    Returns:
        Tuple of (resolved_model_path, resolved_cache_dir)
    """
    # Default remote model names for each model type
    remote_models = {
        'bge_m3': 'BAAI/bge-m3',
        'bge_reranker': 'BAAI/bge-reranker-base',
    }
    # Get model source (huggingface or modelscope)
    model_source = config_instance.get('model_paths', 'source', 'huggingface')
    # ModelScope mirrors the BGE models under the same BAAI identifiers,
    # so no remapping to unrelated models is needed for that source
    if model_source == 'modelscope':
        remote_models = {
            'bge_m3': 'BAAI/bge-m3',
            'bge_reranker': 'BAAI/bge-reranker-base',
        }
    # 1. Check CLI override first (only if explicitly provided)
    if model_name_or_path is not None:
        resolved_model_path = model_name_or_path
        logger.info(f"[CLI OVERRIDE] Using model path: {resolved_model_path}")
    else:
        # 2. Check local path from config
        local_path_raw = config_instance.get('model_paths', model_key, None)
        # Handle list/string conversion
        if isinstance(local_path_raw, list):
            local_path = local_path_raw[0] if local_path_raw else None
        else:
            local_path = local_path_raw
        if local_path and isinstance(local_path, str) and os.path.exists(local_path):
            # Check if it's a valid model directory
            if _is_valid_model_directory(local_path):
                resolved_model_path = local_path
                logger.info(f"[LOCAL] Using local model: {resolved_model_path}")
            else:
                logger.warning(f"[LOCAL] Invalid model directory: {local_path}, falling back to remote")
                resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
        else:
            # 3. Fall back to remote model
            resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
            logger.info(f"[REMOTE] Using remote model: {resolved_model_path}")
            # If local path was configured but doesn't exist, we'll download to that location
            if local_path and isinstance(local_path, str):
                logger.info(f"[DOWNLOAD] Will download to configured path: {local_path}")
                # Create directory if it doesn't exist
                os.makedirs(local_path, exist_ok=True)
    # Resolve cache directory
    if cache_dir is not None:
        resolved_cache_dir = cache_dir
    else:
        cache_dir_raw = config_instance.get('model_paths', 'cache_dir', './cache/data')
        if isinstance(cache_dir_raw, list):
            resolved_cache_dir = cache_dir_raw[0] if cache_dir_raw else './cache/data'
        else:
            resolved_cache_dir = cache_dir_raw if cache_dir_raw else './cache/data'
    # Ensure cache directory exists
    if isinstance(resolved_cache_dir, str):
        os.makedirs(resolved_cache_dir, exist_ok=True)
    return resolved_model_path, resolved_cache_dir
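
# Usage sketch (illustrative; assumes the singleton `config` defined below and
# a config.toml with a [model_paths] section -- the paths shown are hypothetical):
#
#     loader = get_config()
#     model_path, cache_dir = resolve_model_path('bge_m3', loader)
#     # -> ('./models/bge-m3', './cache/data') when a valid local copy exists,
#     #    ('BAAI/bge-m3', './cache/data') as the remote fallback otherwise.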

def _is_valid_model_directory(path: str) -> bool:
    """Check if a directory contains a valid model"""
    required_files = ['config.json']
    optional_files = ['tokenizer_config.json', 'vocab.txt', 'tokenizer.json']
    model_files = ['pytorch_model.bin', 'model.safetensors', 'pytorch_model.bin.index.json']
    path_obj = Path(path)
    if not path_obj.exists() or not path_obj.is_dir():
        return False
    # Check for required files
    for file in required_files:
        if not (path_obj / file).exists():
            return False
    # Check for at least one model file
    has_model_file = any((path_obj / file).exists() for file in model_files)
    if not has_model_file:
        return False
    # Check for at least one tokenizer file (for embedding models)
    has_tokenizer_file = any((path_obj / file).exists() for file in optional_files)
    if not has_tokenizer_file:
        logger.warning(f"Model directory {path} missing tokenizer files, but proceeding anyway")
    return True
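
# A directory that passes the check above typically looks like this
# (illustrative HuggingFace-style layout; the file names are the ones probed for):
#
#     models/bge-m3/
#         config.json             <- required
#         model.safetensors       <- at least one weight file is required
#         tokenizer_config.json   <- missing tokenizer files only trigger a warning
#         tokenizer.json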

def download_model_if_needed(
    model_name_or_path: str,
    local_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    model_source: str = 'huggingface'
) -> str:
    """
    Download model if it's a remote identifier and local path is specified

    Args:
        model_name_or_path: Model identifier or path
        local_path: Target local path for download
        cache_dir: Cache directory
        model_source: 'huggingface' or 'modelscope'

    Returns:
        Final model path (local if downloaded, original if already local)
    """
    # If it's already a local path, return as is
    if os.path.exists(model_name_or_path):
        return model_name_or_path
    # If no local path specified, return original (will be handled by transformers)
    if not local_path:
        return model_name_or_path
    # Check if local path already exists and is valid
    if os.path.exists(local_path) and _is_valid_model_directory(local_path):
        logger.info(f"Model already exists locally: {local_path}")
        return local_path
    # Download the model
    try:
        logger.info(f"Downloading model {model_name_or_path} to {local_path}")
        if model_source == 'huggingface':
            from transformers import AutoModel, AutoTokenizer
            # Download model and tokenizer
            model = AutoModel.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            # Save to local path
            os.makedirs(local_path, exist_ok=True)
            model.save_pretrained(local_path)
            tokenizer.save_pretrained(local_path)
            logger.info(f"Model downloaded successfully to {local_path}")
            return local_path
        elif model_source == 'modelscope':
            try:
                from modelscope import snapshot_download  # type: ignore
                # Download using ModelScope
                downloaded_path = snapshot_download(
                    model_name_or_path,
                    cache_dir=cache_dir,
                    local_dir=local_path
                )
                logger.info(f"Model downloaded successfully from ModelScope to {downloaded_path}")
                return downloaded_path
            except ImportError:
                logger.warning("ModelScope not available, falling back to HuggingFace")
                return download_model_if_needed(
                    model_name_or_path, local_path, cache_dir, 'huggingface'
                )
    except Exception as e:
        logger.error(f"Failed to download model {model_name_or_path}: {e}")
        logger.info("Returning original model path - transformers will handle caching")
    # Fallback return
    return model_name_or_path
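
# Usage sketch (hypothetical identifier and paths; a real call needs network
# access for the first download):
#
#     path = download_model_if_needed(
#         'BAAI/bge-m3',
#         local_path='./models/bge-m3',
#         cache_dir='./cache/data',
#         model_source='huggingface',
#     )
#     # Returns './models/bge-m3' on success, or the original identifier on
#     # failure so that transformers can fall back to its own cache.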

class ConfigLoader:
    """Centralized configuration management for BGE fine-tuning

    Parameter Precedence (highest to lowest):
    1. Command-line arguments (CLI)
    2. Model-specific config ([m3], [reranker])
    3. [hardware] section in config.toml
    4. [training] section in config.toml
    5. Hardcoded defaults in code
    """
    _instance = None
    _config = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if self._config is None:
            self._load_config()

    def _find_config_file(self) -> str:
        """Find config.toml file in project hierarchy"""
        # Check environment variable FIRST (highest priority)
        if 'BGE_CONFIG_PATH' in os.environ:
            config_path = Path(os.environ['BGE_CONFIG_PATH'])
            if config_path.exists():
                return str(config_path)
            else:
                logger.warning(f"BGE_CONFIG_PATH set to {config_path} but file doesn't exist, falling back to default")
        current_dir = Path(__file__).parent
        # Search up the directory tree for config.toml
        for _ in range(5):  # Limit search depth
            config_path = current_dir / "config.toml"
            if config_path.exists():
                return str(config_path)
            current_dir = current_dir.parent
        # Default location
        return os.path.join(os.getcwd(), "config.toml")
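
    # The search above can be short-circuited from the shell, e.g.
    # (hypothetical path and entry point):
    #
    #     BGE_CONFIG_PATH=/etc/bge/config.toml python -m bge_finetune.train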

    def _load_config(self):
        """Load configuration from TOML file"""
        config_path = self._find_config_file()
        try:
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = toml.load(f)
                if loaded is None:
                    loaded = {}
                self._config = dict(loaded)
                logger.info(f"Loaded configuration from {config_path}")
            else:
                logger.warning(f"Config file not found at {config_path}, using defaults")
                self._config = self._get_default_config()
                self._save_default_config(config_path)
                logger.warning("Default config created. Please check your configuration file.")
        except Exception as e:
            logger.error(f"Error loading config: {e}")
            self._config = self._get_default_config()
        # Validation for unset/placeholder values
        self._validate_config()

    def _validate_config(self):
        """Validate config for unset/placeholder values"""
        api_key = self.get('api', 'api_key', None)
        if not api_key or api_key == 'your-api-key-here':
            logger.warning("API key is unset or placeholder. Set it in config.toml or via the BGE_API_KEY environment variable.")

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration"""
        return {
            'api': {
                'base_url': 'https://api.deepseek.com/v1',
                'api_key': 'your-api-key-here',
                'model': 'deepseek-chat',
                'max_retries': 3,
                'timeout': 30,
                'temperature': 0.7
            },
            'training': {
                'default_batch_size': 4,
                'default_learning_rate': 1e-5,
                'default_num_epochs': 3,
                'default_warmup_ratio': 0.1,
                'default_weight_decay': 0.01,
                'gradient_accumulation_steps': 4
            },
            'model_paths': {
                'bge_m3': './models/bge-m3',
                'bge_reranker': './models/bge-reranker-base',
                'cache_dir': './cache/data'
            },
            'data': {
                'max_query_length': 64,
                'max_passage_length': 512,
                'max_seq_length': 512,
                'validation_split': 0.1,
                'random_seed': 42
            },
            'augmentation': {
                'enable_query_augmentation': False,
                'enable_passage_augmentation': False,
                'augmentation_methods': ['paraphrase', 'back_translation'],
                'num_augmentations_per_sample': 2,
                'augmentation_temperature': 0.8
            },
            'hard_negative_mining': {
                'enable_mining': True,
                'num_hard_negatives': 15,
                'negative_range_min': 10,
                'negative_range_max': 100,
                'margin': 0.1,
                'index_refresh_interval': 1000,
                'index_type': 'IVF',
                'nlist': 100
            },
            'evaluation': {
                'eval_batch_size': 32,
                'k_values': [1, 3, 5, 10, 20, 50, 100],
                'metrics': ['recall', 'precision', 'map', 'mrr', 'ndcg'],
                'save_predictions': True,
                'predictions_dir': './output/predictions'
            },
            'logging': {
                'log_level': 'INFO',
                'log_to_file': True,
                'log_dir': './logs',
                'tensorboard_dir': './logs/tensorboard',
                'wandb_project': 'bge-finetuning',
                'wandb_entity': '',
                'use_wandb': False
            },
            'distributed': {
                'backend': 'nccl',
                'init_method': 'env://',
                'find_unused_parameters': True
            },
            'optimization': {
                'gradient_checkpointing': True,
                'fp16': True,
                'bf16': False,
                'fp16_opt_level': 'O1',
                'max_grad_norm': 1.0,
                'adam_epsilon': 1e-8,
                'adam_beta1': 0.9,
                'adam_beta2': 0.999
            },
            'hardware': {
                'cuda_visible_devices': '0,1,2,3',
                'dataloader_num_workers': 4,
                'dataloader_pin_memory': True,
                'per_device_train_batch_size': 4,
                'per_device_eval_batch_size': 8
            },
            'checkpoint': {
                'save_strategy': 'steps',
                'save_steps': 1000,
                'save_total_limit': 3,
                'load_best_model_at_end': True,
                'metric_for_best_model': 'eval_recall@10',
                'greater_is_better': True,
                'resume_from_checkpoint': ''
            },
            'export': {
                'export_format': ['pytorch', 'onnx'],
                'optimize_for_inference': True,
                'quantize': False,
                'quantization_bits': 8
            }
        }
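
    # A minimal config.toml overriding a few of these defaults might look like
    # this (illustrative values):
    #
    #     [api]
    #     api_key = "your-api-key-here"
    #
    #     [model_paths]
    #     bge_m3 = "/data/models/bge-m3"
    #
    #     [training]
    #     default_batch_size = 8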

    def _save_default_config(self, config_path: str):
        """Save default configuration to file"""
        try:
            # Guard against a bare filename, where dirname() returns ''
            parent_dir = os.path.dirname(config_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                config_to_save = self._config if isinstance(self._config, dict) else {}
                toml.dump(config_to_save, f)
            logger.info(f"Created default config at {config_path}")
        except Exception as e:
            logger.error(f"Error saving default config: {e}")

    def get(self, section: str, key: str, default=None):
        """Get a config value, logging its source (config.toml or default)."""
        # CLI overrides are handled outside this loader; here we check config and default
        value = default
        source = 'default'
        if isinstance(self._config, dict) and section in self._config and key in self._config[section]:
            value = self._config[section][key]
            source = 'config.toml'
        # Robust list conversion for comma/semicolon-separated strings
        if isinstance(value, str) and (',' in value or ';' in value):
            try:
                value = [v.strip() for v in value.replace(';', ',').split(',') if v.strip()]
                source += ' (parsed as list)'
            except Exception as e:
                logger.warning(f"Failed to parse list for {section}.{key}: {e}")
        log_func = logger.info if source != 'default' else logger.warning
        log_func(f"Config param {section}.{key} = {value} (source: {source})")
        return value
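
    # Examples of the lookup and list parsing above (illustrative):
    #
    #     loader.get('training', 'default_num_epochs', 3)
    #     # -> value from config.toml if present, else 3 (logged as 'default')
    #     loader.get('hardware', 'cuda_visible_devices', '0')
    #     # -> '0,1,2,3' is split into ['0', '1', '2', '3'] (parsed as list)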

    def get_section(self, section: str) -> Dict[str, Any]:
        """Get entire configuration section"""
        if isinstance(self._config, dict):
            return self._config.get(section, {})
        return {}

    def update(self, key: str, value: Any):
        """Update configuration value. Logs a warning if overriding an existing value."""
        if self._config is None:
            self._config = {}
        keys = key.split('.')
        config = self._config
        # Navigate to the parent of the final key
        for k in keys[:-1]:
            if k not in config:
                config[k] = {}
            config = config[k]
        # Log if overriding
        old_value = config.get(keys[-1], None)
        if old_value is not None and old_value != value:
            logger.warning(f"Parameter '{key}' overridden: {old_value} -> {value} (higher-precedence source)")
        config[keys[-1]] = value
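
    # Dotted keys address nested sections, e.g. (illustrative):
    #
    #     loader.update('training.default_batch_size', 16)
    #     # warns about the override if a different value was already set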

    def reload(self):
        """Reload configuration from file"""
        self._config = None
        self._load_config()


# Global instance
config = ConfigLoader()


def get_config() -> ConfigLoader:
    """Get global configuration instance"""
    return config

def load_config_for_training(
    config_path: Optional[str] = None,
    override_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Load configuration specifically for training scripts

    Args:
        config_path: Optional path to config file
        override_params: Parameters to override from command line

    Returns:
        Dictionary with all configuration parameters
    """
    global config
    # Reload with specific path if provided
    if config_path:
        old_path = os.environ.get('BGE_CONFIG_PATH')
        os.environ['BGE_CONFIG_PATH'] = config_path
        config.reload()
        if old_path:
            os.environ['BGE_CONFIG_PATH'] = old_path
        else:
            del os.environ['BGE_CONFIG_PATH']
    # Get all configuration
    training_config = {
        'api': config.get_section('api'),
        'training': config.get_section('training'),
        'model_paths': config.get_section('model_paths'),
        'data': config.get_section('data'),
        'augmentation': config.get_section('augmentation'),
        'hard_negative_mining': config.get_section('hard_negative_mining'),
        'evaluation': config.get_section('evaluation'),
        'logging': config.get_section('logging'),
        'distributed': config.get_section('distributed'),
        'optimization': config.get_section('optimization'),
        'hardware': config.get_section('hardware'),
        'checkpoint': config.get_section('checkpoint'),
        'export': config.get_section('export'),
        'm3': config.get_section('m3'),
        'reranker': config.get_section('reranker'),
    }
    # Apply overrides if provided
    if override_params:
        for key, value in override_params.items():
            config.update(key, value)
            # Also update the returned dict
            keys = key.split('.')
            current = training_config
            for k in keys[:-1]:
                if k not in current:
                    current[k] = {}
                current = current[k]
            current[keys[-1]] = value
    return training_config
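
if __name__ == '__main__':
    # Minimal, self-contained smoke test of the loader (an illustrative sketch,
    # not part of the training pipeline). Note that it may create a default
    # config.toml plus the configured model/cache directories under the current
    # working directory.
    logging.basicConfig(level=logging.INFO)
    loader = get_config()
    resolved_model, resolved_cache = resolve_model_path('bge_m3', loader)
    print(f"Resolved model path: {resolved_model}")
    print(f"Resolved cache dir: {resolved_cache}")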