519 lines
19 KiB
Python
519 lines
19 KiB
Python
"""
|
|
Centralized configuration loader for BGE fine-tuning project
|
|
"""
|
|
import os
|
|
import toml
|
|
import logging
|
|
from typing import Dict, Any, Optional, Tuple
|
|
from pathlib import Path
|
|
|
|
# Module-scoped logger shared by every function and class defined in this file.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
def resolve_model_path(
    model_key: str,
    config_instance: 'ConfigLoader',
    model_name_or_path: Optional[str] = None,
    cache_dir: Optional[str] = None
) -> Tuple[str, str]:
    """
    Resolve model path with automatic local/remote fallback.

    Resolution order:
        1. ``model_name_or_path`` (CLI override) whenever it is not ``None``.
        2. ``[model_paths].<model_key>`` from the config, if that path exists
           and passes ``_is_valid_model_directory``.
        3. A remote model identifier for ``model_key`` (HuggingFace ids by
           default, ModelScope ids when ``[model_paths].source == 'modelscope'``;
           unknown keys map to ``'BAAI/<model_key>'``).

    Side effects: creates the resolved cache directory and, best effort, the
    configured local model directory when it does not exist yet.

    Args:
        model_key: Key in model_paths config (e.g., 'bge_m3', 'bge_reranker')
        config_instance: ConfigLoader instance
        model_name_or_path: Optional override from CLI
        cache_dir: Optional cache directory

    Returns:
        Tuple of (resolved_model_path, resolved_cache_dir)

    Raises:
        OSError: if the resolved cache directory cannot be created.
    """
    # Default remote model names for each model type
    remote_models = {
        'bge_m3': 'BAAI/bge-m3',
        'bge_reranker': 'BAAI/bge-reranker-base',
    }

    # Get model source (huggingface or modelscope)
    model_source = config_instance.get('model_paths', 'source', 'huggingface')

    if model_source == 'modelscope':
        # NOTE(review): these ModelScope ids name GTE / CoROM checkpoints rather
        # than BGE-M3 / bge-reranker-base, so switching the source also switches
        # the model family. Confirm this is intended.
        remote_models = {
            'bge_m3': 'damo/nlp_gte_sentence-embedding_chinese-base',
            'bge_reranker': 'damo/nlp_corom_passage-ranking_english-base',
        }

    remote_default = remote_models.get(model_key, f'BAAI/{model_key}')

    if model_name_or_path is not None:
        # 1. CLI override wins whenever it was explicitly provided.
        resolved_model_path = model_name_or_path
        logger.info(f"[CLI OVERRIDE] Using model path: {resolved_model_path}")
    else:
        # 2. Local path from config. ConfigLoader.get may split comma-separated
        #    strings into a list, so take the first entry in that case.
        local_path_raw = config_instance.get('model_paths', model_key, None)
        if isinstance(local_path_raw, list):
            local_path = local_path_raw[0] if local_path_raw else None
        else:
            local_path = local_path_raw

        if local_path and isinstance(local_path, str) and os.path.exists(local_path):
            if _is_valid_model_directory(local_path):
                resolved_model_path = local_path
                logger.info(f"[LOCAL] Using local model: {resolved_model_path}")
            else:
                logger.warning(f"[LOCAL] Invalid model directory: {local_path}, falling back to remote")
                resolved_model_path = remote_default
                logger.info(f"[REMOTE] Using remote model: {resolved_model_path}")
        else:
            # 3. Fall back to remote model
            resolved_model_path = remote_default
            logger.info(f"[REMOTE] Using remote model: {resolved_model_path}")

            # A configured-but-missing local path is prepared as a download target.
            if local_path and isinstance(local_path, str):
                logger.info(f"[DOWNLOAD] Will download to configured path: {local_path}")
                try:
                    os.makedirs(local_path, exist_ok=True)
                except OSError as e:
                    # Best effort only: the remote id above is still usable, so a
                    # read-only or invalid target directory must not abort resolution.
                    logger.warning(f"[DOWNLOAD] Could not create {local_path}: {e}")

    # Resolve cache directory (CLI value first, then config, then built-in default).
    if cache_dir is not None:
        resolved_cache_dir = cache_dir
    else:
        cache_dir_raw = config_instance.get('model_paths', 'cache_dir', './cache/data')
        if isinstance(cache_dir_raw, list):
            resolved_cache_dir = cache_dir_raw[0] if cache_dir_raw else './cache/data'
        else:
            resolved_cache_dir = cache_dir_raw if cache_dir_raw else './cache/data'

    # Ensure cache directory exists
    if isinstance(resolved_cache_dir, str):
        os.makedirs(resolved_cache_dir, exist_ok=True)

    return resolved_model_path, resolved_cache_dir
|
|
|
|
|
|
def _is_valid_model_directory(path: str) -> bool:
    """Check whether ``path`` looks like a saved model directory.

    A directory is accepted when it contains ``config.json`` and at least one
    weight file: single-file PyTorch/safetensors weights or the index file of
    a sharded checkpoint in either format. A missing tokenizer only triggers a
    warning, not a rejection.

    Args:
        path: Directory to inspect.

    Returns:
        True if the required files are present; False otherwise, including
        when ``path`` does not exist or is not a directory.
    """
    required_files = ['config.json']
    # Any one of these is treated as evidence that tokenizer files were saved.
    tokenizer_files = [
        'tokenizer_config.json',
        'vocab.txt',
        'tokenizer.json',
        'sentencepiece.bpe.model',
    ]
    # Single-file weights plus the index files written for sharded checkpoints,
    # covering both the .bin and the safetensors serialization formats.
    model_files = [
        'pytorch_model.bin',
        'model.safetensors',
        'pytorch_model.bin.index.json',
        'model.safetensors.index.json',
    ]

    path_obj = Path(path)
    # is_dir() is False for non-existent paths, so this covers both checks.
    if not path_obj.is_dir():
        return False

    if not all((path_obj / name).exists() for name in required_files):
        return False

    if not any((path_obj / name).exists() for name in model_files):
        return False

    if not any((path_obj / name).exists() for name in tokenizer_files):
        logger.warning(f"Model directory {path} missing tokenizer files, but proceeding anyway")

    return True
|
|
|
|
|
|
def _load_hf_model_with_head(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Load a HuggingFace model using the architecture its config declares.

    ``AutoModel`` instantiates the bare encoder, so re-saving a task checkpoint
    (e.g. a sequence-classification reranker) through it drops the task-head
    weights. Resolving the class named in ``config.architectures`` keeps them.
    Architectures not exported by ``transformers`` fall back to ``AutoModel``.
    """
    import transformers
    from transformers import AutoConfig, AutoModel

    # SECURITY(review): trust_remote_code=True executes Python code shipped in
    # the model repository; confirm the configured models actually need it.
    hf_config = AutoConfig.from_pretrained(
        model_name_or_path,
        cache_dir=cache_dir,
        trust_remote_code=True
    )
    for arch_name in getattr(hf_config, 'architectures', None) or []:
        model_cls = getattr(transformers, arch_name, None)
        if model_cls is not None and hasattr(model_cls, 'from_pretrained'):
            return model_cls.from_pretrained(model_name_or_path, cache_dir=cache_dir)

    return AutoModel.from_pretrained(
        model_name_or_path,
        cache_dir=cache_dir,
        trust_remote_code=True
    )


def download_model_if_needed(
    model_name_or_path: str,
    local_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    model_source: str = 'huggingface'
) -> str:
    """
    Download model if it's a remote identifier and local path is specified.

    Args:
        model_name_or_path: Model identifier or path
        local_path: Target local path for download
        cache_dir: Cache directory
        model_source: 'huggingface' or 'modelscope'

    Returns:
        ``model_name_or_path`` unchanged when it already exists on disk, when
        no ``local_path`` is given, when the download fails, or when
        ``model_source`` is unknown. Otherwise the local directory holding the
        model (for ModelScope, the path reported by ``snapshot_download``).
        Download errors are logged, never raised.
    """
    # If it's already a local path, return as is
    if os.path.exists(model_name_or_path):
        return model_name_or_path

    # If no local path specified, return original (will be handled by transformers)
    if not local_path:
        return model_name_or_path

    # Check if local path already exists and is valid
    if os.path.exists(local_path) and _is_valid_model_directory(local_path):
        logger.info(f"Model already exists locally: {local_path}")
        return local_path

    try:
        logger.info(f"Downloading model {model_name_or_path} to {local_path}")

        if model_source == 'huggingface':
            from transformers import AutoTokenizer

            # Load with the checkpoint's own architecture so any task head
            # survives the save_pretrained round trip below.
            model = _load_hf_model_with_head(model_name_or_path, cache_dir)
            # SECURITY(review): trust_remote_code=True runs repository code.
            tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                cache_dir=cache_dir,
                trust_remote_code=True
            )

            os.makedirs(local_path, exist_ok=True)
            model.save_pretrained(local_path)
            tokenizer.save_pretrained(local_path)

            logger.info(f"Model downloaded successfully to {local_path}")
            return local_path

        elif model_source == 'modelscope':
            try:
                from modelscope import snapshot_download  # type: ignore

                downloaded_path = snapshot_download(
                    model_name_or_path,
                    cache_dir=cache_dir,
                    local_dir=local_path
                )

                logger.info(f"Model downloaded successfully from ModelScope to {downloaded_path}")
                return downloaded_path

            except ImportError:
                logger.warning("ModelScope not available, falling back to HuggingFace")
                return download_model_if_needed(
                    model_name_or_path, local_path, cache_dir, 'huggingface'
                )

        else:
            # Previously an unknown source fell through silently.
            logger.warning(
                f"Unknown model_source '{model_source}' (expected 'huggingface' or 'modelscope'); "
                f"returning original model path"
            )

    except Exception as e:
        # Best effort: a failed download degrades to the remote identifier.
        logger.error(f"Failed to download model {model_name_or_path}: {e}")
        logger.info("Returning original model path - transformers will handle caching")

    return model_name_or_path
|
|
|
|
|
|
class ConfigLoader:
    """Centralized configuration management for BGE fine-tuning.

    Process-wide singleton: every ``ConfigLoader()`` call returns the same
    instance. The TOML file is read only on first construction or on
    :meth:`reload`.

    Parameter Precedence (highest to lowest):
        1. Command-line arguments (CLI)
        2. Model-specific config ([m3], [reranker])
        3. [hardware] section in config.toml
        4. [training] section in config.toml
        5. Hardcoded defaults in code

    This class does not enforce that ordering itself. It returns the values
    stored in the loaded file (or the built-in defaults). The one exception is
    the ``BGE_API_KEY`` environment variable, which, when set, replaces
    ``api.api_key``.
    """

    _instance = None
    _config = None

    # Config keys containing any of these substrings have their values masked in logs.
    _SECRET_KEY_MARKERS = ('key', 'token', 'secret', 'password')
    # Environment variable that supplies api.api_key without writing it to config.toml.
    _API_KEY_ENV_VAR = 'BGE_API_KEY'

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # The singleton is re-initialised on every ConfigLoader() call; load only once.
        if self._config is None:
            self._load_config()

    def _find_config_file(self) -> str:
        """Return the config.toml path to use.

        Order: ``$BGE_CONFIG_PATH`` if it points to an existing file; then
        ``config.toml`` in this module's directory or up to four parents;
        otherwise ``<cwd>/config.toml``, which may not exist yet.
        """
        # Check environment variable FIRST (highest priority)
        if 'BGE_CONFIG_PATH' in os.environ:
            config_path = Path(os.environ['BGE_CONFIG_PATH'])
            if config_path.exists():
                return str(config_path)
            logger.warning(f"BGE_CONFIG_PATH set to {config_path} but file doesn't exist, falling back to default")

        # Normalise to an absolute path so walking up .parent is meaningful
        # even when __file__ is relative.
        current_dir = Path(os.path.abspath(__file__)).parent

        # Search up the directory tree for config.toml
        for _ in range(5):  # Limit search depth
            config_path = current_dir / "config.toml"
            if config_path.exists():
                return str(config_path)
            current_dir = current_dir.parent

        # Default location
        return os.path.join(os.getcwd(), "config.toml")

    def _load_config(self):
        """Load configuration from TOML, falling back to (and saving) defaults."""
        config_path = self._find_config_file()

        try:
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = toml.load(f)
                if loaded is None:
                    loaded = {}
                self._config = dict(loaded)
                logger.info(f"Loaded configuration from {config_path}")
            else:
                logger.warning(f"Config file not found at {config_path}, using defaults")
                self._config = self._get_default_config()
                self._save_default_config(config_path)
                logger.warning("Default config created. Please check your configuration file.")
        except Exception as e:
            # Best effort: a malformed or unreadable file degrades to defaults.
            logger.error(f"Error loading config: {e}")
            self._config = self._get_default_config()

        # Applied after any default file was written, so a secret supplied via
        # the environment is never persisted to disk.
        self._apply_env_overrides()

        # Validation for unset/placeholder values
        self._validate_config()

    def _apply_env_overrides(self):
        """Copy ``$BGE_API_KEY`` (when non-empty) into ``api.api_key``."""
        env_api_key = os.environ.get(self._API_KEY_ENV_VAR)
        if not env_api_key:
            return
        if not isinstance(self._config, dict):
            self._config = {}
        api_section = self._config.get('api')
        if not isinstance(api_section, dict):
            api_section = {}
            self._config['api'] = api_section
        api_section['api_key'] = env_api_key
        logger.info(f"Using API key from the {self._API_KEY_ENV_VAR} environment variable")

    def _validate_config(self):
        """Warn when the API key is unset or still the shipped placeholder."""
        api_key = self.get('api', 'api_key', None)
        if not api_key or api_key == 'your-api-key-here':
            logger.warning("API key is unset or placeholder. Set it in config.toml or via the BGE_API_KEY environment variable.")

    def _get_default_config(self) -> Dict[str, Any]:
        """Return a fresh copy of the built-in default configuration."""
        return {
            'api': {
                'base_url': 'https://api.deepseek.com/v1',
                'api_key': 'your-api-key-here',
                'model': 'deepseek-chat',
                'max_retries': 3,
                'timeout': 30,
                'temperature': 0.7
            },
            'training': {
                'default_batch_size': 4,
                'default_learning_rate': 1e-5,
                'default_num_epochs': 3,
                'default_warmup_ratio': 0.1,
                'default_weight_decay': 0.01,
                'gradient_accumulation_steps': 4
            },
            'model_paths': {
                'bge_m3': './models/bge-m3',
                'bge_reranker': './models/bge-reranker-base',
                'cache_dir': './cache/data'
            },
            'data': {
                'max_query_length': 64,
                'max_passage_length': 512,
                'max_seq_length': 512,
                'validation_split': 0.1,
                'random_seed': 42
            },
            'augmentation': {
                'enable_query_augmentation': False,
                'enable_passage_augmentation': False,
                'augmentation_methods': ['paraphrase', 'back_translation'],
                'num_augmentations_per_sample': 2,
                'augmentation_temperature': 0.8
            },
            'hard_negative_mining': {
                'enable_mining': True,
                'num_hard_negatives': 15,
                'negative_range_min': 10,
                'negative_range_max': 100,
                'margin': 0.1,
                'index_refresh_interval': 1000,
                'index_type': 'IVF',
                'nlist': 100
            },
            'evaluation': {
                'eval_batch_size': 32,
                'k_values': [1, 3, 5, 10, 20, 50, 100],
                'metrics': ['recall', 'precision', 'map', 'mrr', 'ndcg'],
                'save_predictions': True,
                'predictions_dir': './output/predictions'
            },
            'logging': {
                'log_level': 'INFO',
                'log_to_file': True,
                'log_dir': './logs',
                'tensorboard_dir': './logs/tensorboard',
                'wandb_project': 'bge-finetuning',
                'wandb_entity': '',
                'use_wandb': False
            },
            'distributed': {
                'backend': 'nccl',
                'init_method': 'env://',
                'find_unused_parameters': True
            },
            'optimization': {
                'gradient_checkpointing': True,
                'fp16': True,
                'bf16': False,
                'fp16_opt_level': 'O1',
                'max_grad_norm': 1.0,
                'adam_epsilon': 1e-8,
                'adam_beta1': 0.9,
                'adam_beta2': 0.999
            },
            'hardware': {
                'cuda_visible_devices': '0,1,2,3',
                'dataloader_num_workers': 4,
                'dataloader_pin_memory': True,
                'per_device_train_batch_size': 4,
                'per_device_eval_batch_size': 8
            },
            'checkpoint': {
                'save_strategy': 'steps',
                'save_steps': 1000,
                'save_total_limit': 3,
                'load_best_model_at_end': True,
                'metric_for_best_model': 'eval_recall@10',
                'greater_is_better': True,
                'resume_from_checkpoint': ''
            },
            'export': {
                'export_format': ['pytorch', 'onnx'],
                'optimize_for_inference': True,
                'quantize': False,
                'quantization_bits': 8
            }
        }

    def _save_default_config(self, config_path: str):
        """Write the current config to ``config_path``; errors are logged, not raised."""
        try:
            parent_dir = os.path.dirname(config_path)
            # os.makedirs('') raises, so skip it for bare filenames.
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                config_to_save = self._config if isinstance(self._config, dict) else {}
                toml.dump(config_to_save, f)
            logger.info(f"Created default config at {config_path}")
        except Exception as e:
            logger.error(f"Error saving default config: {e}")

    @classmethod
    def _redact(cls, key: Any, value: Any) -> Any:
        """Return a log-safe form of ``value``: masked when ``key`` looks secret and value is set."""
        lowered = str(key).lower()
        if value and any(marker in lowered for marker in cls._SECRET_KEY_MARKERS):
            return '***'
        return value

    def get(self, section: str, key: str, default=None):
        """Return ``[section].key`` from the loaded config, or ``default``.

        A string value containing ``,`` or ``;`` is returned as a list of its
        stripped, non-empty items. Each lookup is logged with its source: INFO
        when the value came from the config file, WARNING when the default is
        returned unchanged. Secret-looking keys are masked in the log. CLI
        overrides are handled outside this loader.
        """
        value = default
        source = 'default'

        section_data = self._config.get(section) if isinstance(self._config, dict) else None
        # Only tables hold keys; on a scalar section `key in ...` would be a
        # substring test and indexing would raise.
        if isinstance(section_data, dict) and key in section_data:
            value = section_data[key]
            source = 'config.toml'

        # Robust list conversion for comma/semicolon separated strings
        if isinstance(value, str) and (',' in value or ';' in value):
            value = [item.strip() for item in value.replace(';', ',').split(',') if item.strip()]
            source += ' (parsed as list)'

        # Module logger rather than logging.info(): the root-level helpers call
        # basicConfig() implicitly when the root logger has no handlers.
        log_func = logger.info if source != 'default' else logger.warning
        log_func("Config param %s.%s = %s (source: %s)", section, key, self._redact(key, value), source)
        return value

    def get_section(self, section: str) -> Dict[str, Any]:
        """Return a whole section (the stored dict itself, not a copy), or ``{}``."""
        if isinstance(self._config, dict):
            return self._config.get(section, {})
        return {}

    def update(self, key: str, value: Any):
        """Set a dotted-path value (e.g. ``'training.lr'``), creating missing tables.

        Logs a warning when an existing non-None value is replaced by a
        different one; secret-looking values are masked in that message.
        """
        if self._config is None:
            self._config = {}

        keys = key.split('.')
        node = self._config

        # Navigate to the parent of the final key
        for part in keys[:-1]:
            if part not in node:
                node[part] = {}
            node = node[part]

        leaf = keys[-1]
        old_value = node.get(leaf, None)
        if old_value is not None and old_value != value:
            logger.warning(
                "Parameter '%s' overridden: %s -> %s (higher-precedence source)",
                key, self._redact(leaf, old_value), self._redact(leaf, value),
            )
        node[leaf] = value

    def reload(self):
        """Discard the in-memory config and load it again from file."""
        self._config = None
        self._load_config()
|
|
|
|
|
|
# Global instance. Constructing it at import time runs ConfigLoader.__init__,
# which loads config.toml and creates a default file when none is found.
# ConfigLoader is a singleton, so later ConfigLoader() calls return this same object.
config: ConfigLoader = ConfigLoader()
|
|
|
|
|
|
def get_config() -> ConfigLoader:
    """Return the module-level ``ConfigLoader`` singleton.

    Returns:
        The shared instance created when this module was imported; all
        callers observe (and mutate) the same configuration state.
    """
    shared_loader: ConfigLoader = config
    return shared_loader
|
|
|
|
|
|
def load_config_for_training(
    config_path: Optional[str] = None,
    override_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Load configuration specifically for training scripts.

    When ``config_path`` is given, the global loader is reloaded from it by
    temporarily pointing ``BGE_CONFIG_PATH`` at that file. The previous value
    of the variable is restored afterwards, even if the reload raises.

    Args:
        config_path: Optional path to config file
        override_params: Parameters to override from command line, keyed by
            dotted path (e.g. ``{'training.default_batch_size': 8}``). They are
            written into the global loader and into the returned dict.

    Returns:
        Dictionary mapping each known section name to its settings (``{}`` for
        sections absent from the config). Sections that exist in the loaded
        config are the loader's own dicts, because ``get_section`` does not
        copy, so mutating them also changes the global configuration.
    """
    # Reload with specific path if provided
    if config_path:
        previous_path = os.environ.get('BGE_CONFIG_PATH')
        os.environ['BGE_CONFIG_PATH'] = config_path
        try:
            config.reload()
        finally:
            # Restore the caller's environment exactly, including an empty value.
            if previous_path is not None:
                os.environ['BGE_CONFIG_PATH'] = previous_path
            else:
                os.environ.pop('BGE_CONFIG_PATH', None)

    section_names = (
        'api',
        'training',
        'model_paths',
        'data',
        'augmentation',
        'hard_negative_mining',
        'evaluation',
        'logging',
        'distributed',
        'optimization',
        'hardware',
        'checkpoint',
        'export',
        'm3',
        'reranker',
    )
    training_config = {name: config.get_section(name) for name in section_names}

    # Apply overrides if provided
    if override_params:
        for key, value in override_params.items():
            config.update(key, value)

            # Mirror into the returned dict too: sections missing from the
            # loaded config were returned as fresh dicts not shared with it.
            keys = key.split('.')
            current = training_config
            for part in keys[:-1]:
                if part not in current:
                    current[part] = {}
                current = current[part]
            current[keys[-1]] = value

    return training_config
|