init commit
This commit is contained in:
518
utils/config_loader.py
Normal file
518
utils/config_loader.py
Normal file
@@ -0,0 +1,518 @@
|
||||
"""
|
||||
Centralized configuration loader for BGE fine-tuning project
|
||||
"""
|
||||
import os
|
||||
import toml
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resolve_model_path(
|
||||
model_key: str,
|
||||
config_instance: 'ConfigLoader',
|
||||
model_name_or_path: Optional[str] = None,
|
||||
cache_dir: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Resolve model path with automatic local/remote fallback
|
||||
|
||||
Args:
|
||||
model_key: Key in model_paths config (e.g., 'bge_m3', 'bge_reranker')
|
||||
config_instance: ConfigLoader instance
|
||||
model_name_or_path: Optional override from CLI
|
||||
cache_dir: Optional cache directory
|
||||
|
||||
Returns:
|
||||
Tuple of (resolved_model_path, resolved_cache_dir)
|
||||
"""
|
||||
# Default remote model names for each model type
|
||||
remote_models = {
|
||||
'bge_m3': 'BAAI/bge-m3',
|
||||
'bge_reranker': 'BAAI/bge-reranker-base',
|
||||
}
|
||||
|
||||
# Get model source (huggingface or modelscope)
|
||||
model_source = config_instance.get('model_paths', 'source', 'huggingface')
|
||||
|
||||
# If ModelScope is used, adjust model names
|
||||
if model_source == 'modelscope':
|
||||
remote_models = {
|
||||
'bge_m3': 'damo/nlp_gte_sentence-embedding_chinese-base',
|
||||
'bge_reranker': 'damo/nlp_corom_passage-ranking_english-base',
|
||||
}
|
||||
|
||||
# 1. Check CLI override first (only if explicitly provided)
|
||||
if model_name_or_path is not None:
|
||||
resolved_model_path = model_name_or_path
|
||||
logger.info(f"[CLI OVERRIDE] Using model path: {resolved_model_path}")
|
||||
else:
|
||||
# 2. Check local path from config
|
||||
local_path_raw = config_instance.get('model_paths', model_key, None)
|
||||
|
||||
# Handle list/string conversion
|
||||
if isinstance(local_path_raw, list):
|
||||
local_path = local_path_raw[0] if local_path_raw else None
|
||||
else:
|
||||
local_path = local_path_raw
|
||||
|
||||
if local_path and isinstance(local_path, str) and os.path.exists(local_path):
|
||||
# Check if it's a valid model directory
|
||||
if _is_valid_model_directory(local_path):
|
||||
resolved_model_path = local_path
|
||||
logger.info(f"[LOCAL] Using local model: {resolved_model_path}")
|
||||
else:
|
||||
logger.warning(f"[LOCAL] Invalid model directory: {local_path}, falling back to remote")
|
||||
resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
|
||||
else:
|
||||
# 3. Fall back to remote model
|
||||
resolved_model_path = remote_models.get(model_key, f'BAAI/{model_key}')
|
||||
logger.info(f"[REMOTE] Using remote model: {resolved_model_path}")
|
||||
|
||||
# If local path was configured but doesn't exist, we'll download to that location
|
||||
if local_path and isinstance(local_path, str):
|
||||
logger.info(f"[DOWNLOAD] Will download to configured path: {local_path}")
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(local_path, exist_ok=True)
|
||||
|
||||
# Resolve cache directory
|
||||
if cache_dir is not None:
|
||||
resolved_cache_dir = cache_dir
|
||||
else:
|
||||
cache_dir_raw = config_instance.get('model_paths', 'cache_dir', './cache/data')
|
||||
if isinstance(cache_dir_raw, list):
|
||||
resolved_cache_dir = cache_dir_raw[0] if cache_dir_raw else './cache/data'
|
||||
else:
|
||||
resolved_cache_dir = cache_dir_raw if cache_dir_raw else './cache/data'
|
||||
|
||||
# Ensure cache directory exists
|
||||
if isinstance(resolved_cache_dir, str):
|
||||
os.makedirs(resolved_cache_dir, exist_ok=True)
|
||||
|
||||
return resolved_model_path, resolved_cache_dir
|
||||
|
||||
|
||||
def _is_valid_model_directory(path: str) -> bool:
|
||||
"""Check if a directory contains a valid model"""
|
||||
required_files = ['config.json']
|
||||
optional_files = ['tokenizer_config.json', 'vocab.txt', 'tokenizer.json']
|
||||
model_files = ['pytorch_model.bin', 'model.safetensors', 'pytorch_model.bin.index.json']
|
||||
|
||||
path_obj = Path(path)
|
||||
if not path_obj.exists() or not path_obj.is_dir():
|
||||
return False
|
||||
|
||||
# Check for required files
|
||||
for file in required_files:
|
||||
if not (path_obj / file).exists():
|
||||
return False
|
||||
|
||||
# Check for at least one model file
|
||||
has_model_file = any((path_obj / file).exists() for file in model_files)
|
||||
if not has_model_file:
|
||||
return False
|
||||
|
||||
# Check for at least one tokenizer file (for embedding models)
|
||||
has_tokenizer_file = any((path_obj / file).exists() for file in optional_files)
|
||||
if not has_tokenizer_file:
|
||||
logger.warning(f"Model directory {path} missing tokenizer files, but proceeding anyway")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def download_model_if_needed(
|
||||
model_name_or_path: str,
|
||||
local_path: Optional[str] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
model_source: str = 'huggingface'
|
||||
) -> str:
|
||||
"""
|
||||
Download model if it's a remote identifier and local path is specified
|
||||
|
||||
Args:
|
||||
model_name_or_path: Model identifier or path
|
||||
local_path: Target local path for download
|
||||
cache_dir: Cache directory
|
||||
model_source: 'huggingface' or 'modelscope'
|
||||
|
||||
Returns:
|
||||
Final model path (local if downloaded, original if already local)
|
||||
"""
|
||||
# If it's already a local path, return as is
|
||||
if os.path.exists(model_name_or_path):
|
||||
return model_name_or_path
|
||||
|
||||
# If no local path specified, return original (will be handled by transformers)
|
||||
if not local_path:
|
||||
return model_name_or_path
|
||||
|
||||
# Check if local path already exists and is valid
|
||||
if os.path.exists(local_path) and _is_valid_model_directory(local_path):
|
||||
logger.info(f"Model already exists locally: {local_path}")
|
||||
return local_path
|
||||
|
||||
# Download the model
|
||||
try:
|
||||
logger.info(f"Downloading model {model_name_or_path} to {local_path}")
|
||||
|
||||
if model_source == 'huggingface':
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
# Download model and tokenizer
|
||||
model = AutoModel.from_pretrained(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Save to local path
|
||||
os.makedirs(local_path, exist_ok=True)
|
||||
model.save_pretrained(local_path)
|
||||
tokenizer.save_pretrained(local_path)
|
||||
|
||||
logger.info(f"Model downloaded successfully to {local_path}")
|
||||
return local_path
|
||||
|
||||
elif model_source == 'modelscope':
|
||||
try:
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
|
||||
# Download using ModelScope
|
||||
downloaded_path = snapshot_download(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
local_dir=local_path
|
||||
)
|
||||
|
||||
logger.info(f"Model downloaded successfully from ModelScope to {downloaded_path}")
|
||||
return downloaded_path
|
||||
|
||||
except ImportError:
|
||||
logger.warning("ModelScope not available, falling back to HuggingFace")
|
||||
return download_model_if_needed(
|
||||
model_name_or_path, local_path, cache_dir, 'huggingface'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download model {model_name_or_path}: {e}")
|
||||
logger.info("Returning original model path - transformers will handle caching")
|
||||
|
||||
# Fallback return
|
||||
return model_name_or_path
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""Centralized configuration management for BGE fine-tuning
|
||||
|
||||
Parameter Precedence (highest to lowest):
|
||||
1. Command-line arguments (CLI)
|
||||
2. Model-specific config ([m3], [reranker])
|
||||
3. [hardware] section in config.toml
|
||||
4. [training] section in config.toml
|
||||
5. Hardcoded defaults in code
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_config = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ConfigLoader, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._config is None:
|
||||
self._load_config()
|
||||
|
||||
def _find_config_file(self) -> str:
|
||||
"""Find config.toml file in project hierarchy"""
|
||||
# Check environment variable FIRST (highest priority)
|
||||
if 'BGE_CONFIG_PATH' in os.environ:
|
||||
config_path = Path(os.environ['BGE_CONFIG_PATH'])
|
||||
if config_path.exists():
|
||||
return str(config_path)
|
||||
else:
|
||||
logger.warning(f"BGE_CONFIG_PATH set to {config_path} but file doesn't exist, falling back to default")
|
||||
|
||||
current_dir = Path(__file__).parent
|
||||
|
||||
# Search up the directory tree for config.toml
|
||||
for _ in range(5): # Limit search depth
|
||||
config_path = current_dir / "config.toml"
|
||||
if config_path.exists():
|
||||
return str(config_path)
|
||||
current_dir = current_dir.parent
|
||||
|
||||
# Default location
|
||||
return os.path.join(os.getcwd(), "config.toml")
|
||||
|
||||
def _load_config(self):
|
||||
"""Load configuration from TOML file"""
|
||||
config_path = self._find_config_file()
|
||||
|
||||
try:
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
loaded = toml.load(f)
|
||||
if loaded is None:
|
||||
loaded = {}
|
||||
self._config = dict(loaded)
|
||||
logger.info(f"Loaded configuration from {config_path}")
|
||||
else:
|
||||
logger.warning(f"Config file not found at {config_path}, using defaults")
|
||||
self._config = self._get_default_config()
|
||||
self._save_default_config(config_path)
|
||||
logger.warning("Default config created. Please check your configuration file.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading config: {e}")
|
||||
self._config = self._get_default_config()
|
||||
|
||||
# Validation for unset/placeholder values
|
||||
self._validate_config()
|
||||
|
||||
def _validate_config(self):
|
||||
"""Validate config for unset/placeholder values"""
|
||||
api_key = self.get('api', 'api_key', None)
|
||||
if not api_key or api_key == 'your-api-key-here':
|
||||
logger.warning("API key is unset or placeholder. Set it in config.toml or via the BGE_API_KEY environment variable.")
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""Get default configuration"""
|
||||
return {
|
||||
'api': {
|
||||
'base_url': 'https://api.deepseek.com/v1',
|
||||
'api_key': 'your-api-key-here',
|
||||
'model': 'deepseek-chat',
|
||||
'max_retries': 3,
|
||||
'timeout': 30,
|
||||
'temperature': 0.7
|
||||
},
|
||||
'training': {
|
||||
'default_batch_size': 4,
|
||||
'default_learning_rate': 1e-5,
|
||||
'default_num_epochs': 3,
|
||||
'default_warmup_ratio': 0.1,
|
||||
'default_weight_decay': 0.01,
|
||||
'gradient_accumulation_steps': 4
|
||||
},
|
||||
'model_paths': {
|
||||
'bge_m3': './models/bge-m3',
|
||||
'bge_reranker': './models/bge-reranker-base',
|
||||
'cache_dir': './cache/data'
|
||||
},
|
||||
'data': {
|
||||
'max_query_length': 64,
|
||||
'max_passage_length': 512,
|
||||
'max_seq_length': 512,
|
||||
'validation_split': 0.1,
|
||||
'random_seed': 42
|
||||
},
|
||||
'augmentation': {
|
||||
'enable_query_augmentation': False,
|
||||
'enable_passage_augmentation': False,
|
||||
'augmentation_methods': ['paraphrase', 'back_translation'],
|
||||
'num_augmentations_per_sample': 2,
|
||||
'augmentation_temperature': 0.8
|
||||
},
|
||||
'hard_negative_mining': {
|
||||
'enable_mining': True,
|
||||
'num_hard_negatives': 15,
|
||||
'negative_range_min': 10,
|
||||
'negative_range_max': 100,
|
||||
'margin': 0.1,
|
||||
'index_refresh_interval': 1000,
|
||||
'index_type': 'IVF',
|
||||
'nlist': 100
|
||||
},
|
||||
'evaluation': {
|
||||
'eval_batch_size': 32,
|
||||
'k_values': [1, 3, 5, 10, 20, 50, 100],
|
||||
'metrics': ['recall', 'precision', 'map', 'mrr', 'ndcg'],
|
||||
'save_predictions': True,
|
||||
'predictions_dir': './output/predictions'
|
||||
},
|
||||
'logging': {
|
||||
'log_level': 'INFO',
|
||||
'log_to_file': True,
|
||||
'log_dir': './logs',
|
||||
'tensorboard_dir': './logs/tensorboard',
|
||||
'wandb_project': 'bge-finetuning',
|
||||
'wandb_entity': '',
|
||||
'use_wandb': False
|
||||
},
|
||||
'distributed': {
|
||||
'backend': 'nccl',
|
||||
'init_method': 'env://',
|
||||
'find_unused_parameters': True
|
||||
},
|
||||
'optimization': {
|
||||
'gradient_checkpointing': True,
|
||||
'fp16': True,
|
||||
'bf16': False,
|
||||
'fp16_opt_level': 'O1',
|
||||
'max_grad_norm': 1.0,
|
||||
'adam_epsilon': 1e-8,
|
||||
'adam_beta1': 0.9,
|
||||
'adam_beta2': 0.999
|
||||
},
|
||||
'hardware': {
|
||||
'cuda_visible_devices': '0,1,2,3',
|
||||
'dataloader_num_workers': 4,
|
||||
'dataloader_pin_memory': True,
|
||||
'per_device_train_batch_size': 4,
|
||||
'per_device_eval_batch_size': 8
|
||||
},
|
||||
'checkpoint': {
|
||||
'save_strategy': 'steps',
|
||||
'save_steps': 1000,
|
||||
'save_total_limit': 3,
|
||||
'load_best_model_at_end': True,
|
||||
'metric_for_best_model': 'eval_recall@10',
|
||||
'greater_is_better': True,
|
||||
'resume_from_checkpoint': ''
|
||||
},
|
||||
'export': {
|
||||
'export_format': ['pytorch', 'onnx'],
|
||||
'optimize_for_inference': True,
|
||||
'quantize': False,
|
||||
'quantization_bits': 8
|
||||
}
|
||||
}
|
||||
|
||||
def _save_default_config(self, config_path: str):
|
||||
"""Save default configuration to file"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(config_path), exist_ok=True)
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
config_to_save = self._config if isinstance(self._config, dict) else {}
|
||||
toml.dump(config_to_save, f)
|
||||
logger.info(f"Created default config at {config_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving default config: {e}")
|
||||
|
||||
def get(self, section: str, key: str, default=None):
|
||||
"""Get a config value, logging its source (CLI, config, default)."""
|
||||
# CLI overrides are handled outside this loader; here we check config and default
|
||||
value = default
|
||||
source = 'default'
|
||||
if isinstance(self._config, dict) and section in self._config and key in self._config[section]:
|
||||
value = self._config[section][key]
|
||||
source = 'config.toml'
|
||||
# Robust list/tuple conversion
|
||||
if isinstance(value, str) and (',' in value or ';' in value):
|
||||
try:
|
||||
value = [v.strip() for v in value.replace(';', ',').split(',') if v.strip()]
|
||||
source += ' (parsed as list)'
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to parse list for {section}.{key}: {e}")
|
||||
log_func = logging.info if source != 'default' else logging.warning
|
||||
log_func(f"Config param {section}.{key} = {value} (source: {source})")
|
||||
return value
|
||||
|
||||
def get_section(self, section: str) -> Dict[str, Any]:
|
||||
"""Get entire configuration section"""
|
||||
if isinstance(self._config, dict):
|
||||
return self._config.get(section, {})
|
||||
return {}
|
||||
|
||||
def update(self, key: str, value: Any):
|
||||
"""Update configuration value. Logs a warning if overriding an existing value."""
|
||||
if self._config is None:
|
||||
self._config = {}
|
||||
|
||||
keys = key.split('.')
|
||||
config = self._config
|
||||
|
||||
# Navigate to the parent of the final key
|
||||
for k in keys[:-1]:
|
||||
if k not in config:
|
||||
config[k] = {}
|
||||
config = config[k]
|
||||
|
||||
# Log if overriding
|
||||
old_value = config.get(keys[-1], None)
|
||||
if old_value is not None and old_value != value:
|
||||
logger.warning(f"Parameter '{key}' overridden: {old_value} -> {value} (higher-precedence source)")
|
||||
config[keys[-1]] = value
|
||||
|
||||
def reload(self):
|
||||
"""Reload configuration from file"""
|
||||
self._config = None
|
||||
self._load_config()
|
||||
|
||||
|
||||
# Global instance
|
||||
config = ConfigLoader()
|
||||
|
||||
|
||||
def get_config() -> ConfigLoader:
|
||||
"""Get global configuration instance"""
|
||||
return config
|
||||
|
||||
|
||||
def load_config_for_training(
|
||||
config_path: Optional[str] = None,
|
||||
override_params: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load configuration specifically for training scripts
|
||||
|
||||
Args:
|
||||
config_path: Optional path to config file
|
||||
override_params: Parameters to override from command line
|
||||
|
||||
Returns:
|
||||
Dictionary with all configuration parameters
|
||||
"""
|
||||
global config
|
||||
|
||||
# Reload with specific path if provided
|
||||
if config_path:
|
||||
old_path = os.environ.get('BGE_CONFIG_PATH')
|
||||
os.environ['BGE_CONFIG_PATH'] = config_path
|
||||
config.reload()
|
||||
if old_path:
|
||||
os.environ['BGE_CONFIG_PATH'] = old_path
|
||||
else:
|
||||
del os.environ['BGE_CONFIG_PATH']
|
||||
|
||||
# Get all configuration
|
||||
training_config = {
|
||||
'api': config.get_section('api'),
|
||||
'training': config.get_section('training'),
|
||||
'model_paths': config.get_section('model_paths'),
|
||||
'data': config.get_section('data'),
|
||||
'augmentation': config.get_section('augmentation'),
|
||||
'hard_negative_mining': config.get_section('hard_negative_mining'),
|
||||
'evaluation': config.get_section('evaluation'),
|
||||
'logging': config.get_section('logging'),
|
||||
'distributed': config.get_section('distributed'),
|
||||
'optimization': config.get_section('optimization'),
|
||||
'hardware': config.get_section('hardware'),
|
||||
'checkpoint': config.get_section('checkpoint'),
|
||||
'export': config.get_section('export'),
|
||||
'm3': config.get_section('m3'),
|
||||
'reranker': config.get_section('reranker'),
|
||||
}
|
||||
|
||||
# Apply overrides if provided
|
||||
if override_params:
|
||||
for key, value in override_params.items():
|
||||
config.update(key, value)
|
||||
# Also update the returned dict
|
||||
keys = key.split('.')
|
||||
current = training_config
|
||||
for k in keys[:-1]:
|
||||
if k not in current:
|
||||
current[k] = {}
|
||||
current = current[k]
|
||||
current[keys[-1]] = value
|
||||
|
||||
return training_config
|
||||
Reference in New Issue
Block a user