init commit
This commit is contained in:
13
config/__init__.py
Normal file
13
config/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Configuration module for BGE fine-tuning
|
||||
"""
|
||||
from .base_config import BaseConfig
|
||||
from .m3_config import BGEM3Config
|
||||
from .reranker_config import BGERerankerConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
'BaseConfig',
|
||||
'BGEM3Config',
|
||||
'BGERerankerConfig'
|
||||
]
|
||||
560
config/base_config.py
Normal file
560
config/base_config.py
Normal file
@@ -0,0 +1,560 @@
|
||||
import os
|
||||
import torch
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BaseConfig:
    """
    Base configuration for all trainers and models.

    Configuration Precedence Order:
        1. Command-line arguments (highest priority)
        2. Model-specific config sections ([m3], [reranker])
        3. [hardware] section (for hardware-specific overrides)
        4. [training] section (general training defaults)
        5. Dataclass field defaults (lowest priority)

    Example:
        If batch_size is defined in multiple places:
        - CLI arg --batch_size 32 (wins)
        - [m3] batch_size = 16
        - [hardware] per_device_train_batch_size = 8
        - [training] default_batch_size = 4
        - Dataclass default = 2

    Note: Model-specific configs ([m3], [reranker]) take precedence over
    [hardware] for their respective models to allow fine-grained control.
    """
    # --- Loss / objective knobs ---
    temperature: float = 1.0  # Softmax temperature for loss functions
    use_in_batch_negatives: bool = False  # Use in-batch negatives for contrastive loss
    negatives_cross_device: bool = False  # Use negatives across devices (DDP)
    label_smoothing: float = 0.0  # Label smoothing for classification losses
    alpha: float = 1.0  # Weight for auxiliary losses

    # --- Schedule / checkpointing cadence ---
    gradient_accumulation_steps: int = 1  # Steps to accumulate gradients
    num_train_epochs: int = 3  # Number of training epochs
    eval_steps: int = 1  # Evaluate every N epochs
    save_steps: int = 1  # Save checkpoint every N epochs
    warmup_steps: int = 0  # Number of warmup steps for scheduler
    warmup_ratio: float = 0.0  # Warmup ratio for scheduler
    metric_for_best_model: str = "accuracy"  # Metric to select best model
    checkpoint_dir: str = "./checkpoints"  # Directory for saving checkpoints
    resume_from_checkpoint: str = ""  # Path to resume checkpoint
    max_grad_norm: float = 1.0  # Max gradient norm for clipping
    use_amp: bool = False  # Use automatic mixed precision

    # --- Model configuration ---
    model_name_or_path: str = "BAAI/bge-m3"
    cache_dir: Optional[str] = None

    # --- Data configuration ---
    train_data_path: str = "./data/train.jsonl"
    eval_data_path: Optional[str] = "./data/eval.jsonl"
    max_seq_length: int = 512
    query_max_length: int = 64
    passage_max_length: int = 512

    # --- Training configuration ---
    output_dir: str = "./output"
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    adam_epsilon: float = 1e-8
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999

    # --- Advanced training configuration ---
    fp16: bool = True
    bf16: bool = False
    gradient_checkpointing: bool = True
    deepspeed: Optional[str] = None
    local_rank: int = -1
    dataloader_num_workers: int = 4
    dataloader_pin_memory: bool = True

    # --- Batch sizes ---
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 8

    # --- Logging and saving ---
    logging_dir: str = "./logs"
    # For testing only, in production, set to 100
    logging_steps: int = 100
    save_total_limit: int = 3
    load_best_model_at_end: bool = True
    greater_is_better: bool = False
    evaluation_strategy: str = "steps"

    # --- Checkpoint and resume ---
    overwrite_output_dir: bool = False

    # --- Seed ---
    seed: int = 42

    # --- Platform compatibility ---
    device: str = "auto"  # "auto" triggers detection in __post_init__
    n_gpu: int = 0  # Populated by _setup_device()

    def __post_init__(self):
        """Post-initialization with robust config loading and device setup.

        Notes:
            - [hardware] section overrides [training] for overlapping parameters.
            - Model-specific configs ([m3], [reranker]) override [training]/[hardware]
              for their models (applied later by subclass __post_init__).
        """
        # Load configuration from the centralized config system (with fallback).
        config = self._get_config_safely()
        # Apply configuration values with the documented precedence.
        self._apply_config_values(config)
        # Resolve "auto" device and GPU count.
        self._setup_device()

    def _get_config_safely(self):
        """Import the project config loader, falling back to built-in defaults.

        Tries the package-relative import first, then the absolute path used
        when running as a script. (A previous version retried the first import
        a second time — that dead branch has been removed.)
        """
        try:
            from utils.config_loader import get_config
        except ImportError:
            try:
                # Absolute import (for scripts run from outside the package)
                from bge_finetune.utils.config_loader import get_config  # type: ignore[import]
            except ImportError:
                logger.error("Could not import config_loader, using fallback configuration. Please check your environment and config loader path.")
                return self._create_fallback_config()
        return get_config()

    def _create_fallback_config(self):
        """Create a minimal fallback configuration object.

        Mirrors the shape of the real config loader: a `get(dotted_key)` and a
        `get_section(name)` accessor over a nested dict of sane defaults.
        """

        class FallbackConfig:
            def __init__(self):
                self._config = {
                    'training': {
                        'default_batch_size': 4,
                        'default_learning_rate': 1e-5,
                        'default_num_epochs': 3,
                        'default_warmup_ratio': 0.1,
                        'default_weight_decay': 0.01,
                        'gradient_accumulation_steps': 4
                    },
                    'model_paths': {
                        'cache_dir': './cache/data'
                    },
                    'data': {
                        'max_seq_length': 512,
                        'max_query_length': 64,
                        'max_passage_length': 512,
                        'random_seed': 42
                    },
                    'hardware': {
                        'dataloader_num_workers': 4,
                        'dataloader_pin_memory': True,
                        'per_device_train_batch_size': 4,
                        'per_device_eval_batch_size': 8
                    },
                    'optimization': {
                        'fp16': True,
                        'bf16': False,
                        'gradient_checkpointing': True,
                        'max_grad_norm': 1.0,
                        'adam_epsilon': 1e-8,
                        'adam_beta1': 0.9,
                        'adam_beta2': 0.999
                    },
                    'checkpoint': {
                        'save_steps': 1000,
                        'save_total_limit': 3,
                        'load_best_model_at_end': True,
                        'metric_for_best_model': 'eval_loss',
                        'greater_is_better': False
                    },
                    'logging': {
                        'log_dir': './logs'
                    }
                }

            def get(self, key, default=None):
                # Dotted-path lookup, e.g. "model_paths.cache_dir".
                value = self._config
                try:
                    for part in key.split('.'):
                        value = value[part]
                    return value
                except (KeyError, TypeError):
                    return default

            def get_section(self, section):
                return self._config.get(section, {})

        return FallbackConfig()

    def _maybe_set(self, section: Dict[str, Any], key: str, attr: str,
                   level: int = logging.DEBUG) -> None:
        """Copy section[key] onto self.<attr> if the key is present.

        Logs (at `level`) only when the value actually changes.
        """
        if key in section:
            old = getattr(self, attr)
            new = section[key]
            setattr(self, attr, new)
            if old != new:
                logger.log(level, "[config] %s: %r -> %r", attr, old, new)

    def _apply_config_values(self, config):
        """
        Apply configuration values with proper precedence order.

        Precedence (lowest to highest):
        1. Dataclass field defaults (already set)
        2. [training] section values
        3. [hardware] section values (override training)
        4. Model-specific sections (handled in subclasses)
        5. Command-line arguments (handled by training scripts)

        This method handles steps 2-3. Model-specific configs are applied
        in subclass __post_init__ methods AFTER this base method runs.
        """
        try:
            # === STEP 2: [training] section (general defaults) ===
            training = config.get_section('training') or {}
            self._maybe_set(training, 'default_num_epochs', 'num_train_epochs')
            self._maybe_set(training, 'default_batch_size', 'per_device_train_batch_size')
            self._maybe_set(training, 'default_learning_rate', 'learning_rate')
            self._maybe_set(training, 'default_warmup_ratio', 'warmup_ratio')
            self._maybe_set(training, 'gradient_accumulation_steps', 'gradient_accumulation_steps')
            self._maybe_set(training, 'default_weight_decay', 'weight_decay')

            # === STEP 3: [hardware] section (overrides training) ===
            hardware = config.get_section('hardware') or {}
            self._maybe_set(hardware, 'dataloader_num_workers', 'dataloader_num_workers')
            self._maybe_set(hardware, 'dataloader_pin_memory', 'dataloader_pin_memory')
            # Batch-size overrides are surfaced at INFO so they are visible.
            self._maybe_set(hardware, 'per_device_train_batch_size',
                            'per_device_train_batch_size', logging.INFO)
            self._maybe_set(hardware, 'per_device_eval_batch_size', 'per_device_eval_batch_size')

            # === Other sections ===
            # Model paths (only fill cache_dir if the caller left it unset).
            if self.cache_dir is None:
                cache_dir = config.get('model_paths.cache_dir')
                if cache_dir is not None:
                    self.cache_dir = cache_dir
                    logger.debug("cache_dir set from config: %s", self.cache_dir)

            # Data configuration (note the config-key -> attribute renames).
            data = config.get_section('data') or {}
            self._maybe_set(data, 'max_seq_length', 'max_seq_length')
            self._maybe_set(data, 'max_query_length', 'query_max_length')
            self._maybe_set(data, 'max_passage_length', 'passage_max_length')
            self._maybe_set(data, 'random_seed', 'seed')

            # Optimization configuration (config key == attribute name).
            opt = config.get_section('optimization') or {}
            for key in ('fp16', 'bf16', 'gradient_checkpointing', 'max_grad_norm',
                        'adam_epsilon', 'adam_beta1', 'adam_beta2'):
                self._maybe_set(opt, key, key)

            # Checkpoint configuration.
            ckpt = config.get_section('checkpoint') or {}
            for key in ('save_steps', 'save_total_limit', 'load_best_model_at_end',
                        'metric_for_best_model', 'greater_is_better'):
                self._maybe_set(ckpt, key, key)

            # Logging configuration.
            log_cfg = config.get_section('logging') or {}
            self._maybe_set(log_cfg, 'log_dir', 'logging_dir')

        except Exception as e:
            # Config application is best-effort: keep defaults on any failure.
            logger.warning(f"Error applying configuration values: {e}")

    def _setup_device(self):
        """Enhanced device detection with Ascend NPU support.

        "auto" picks the best available backend; an explicit device is kept
        as-is (NPU falls back to CPU when unavailable, matching prior behavior).
        """
        if self.device == "auto":
            device_info = self._detect_available_devices()
            self.device = device_info['device']
            self.n_gpu = device_info['count']
            logger.info(f"Auto-detected device: {self.device} (count: {self.n_gpu})")
        elif self.device == "cuda" and torch.cuda.is_available():
            self.n_gpu = torch.cuda.device_count()
        elif self.device == "npu":
            try:
                import torch_npu  # type: ignore[import]  # noqa: F401 -- registers torch.npu
                npu = getattr(torch, 'npu', None)
                if npu is not None and npu.is_available():
                    self.n_gpu = npu.device_count()
                else:
                    logger.warning("NPU requested but not available, falling back to CPU")
                    self.device = "cpu"
                    self.n_gpu = 0
            except ImportError:
                logger.warning("NPU requested but torch_npu not installed, falling back to CPU")
                self.device = "cpu"
                self.n_gpu = 0
        else:
            # Explicit cpu/mps, or cuda requested without CUDA available.
            self.n_gpu = 0

    def _detect_available_devices(self) -> Dict[str, Any]:
        """Comprehensive device detection with priority ordering npu > cuda > mps > cpu."""
        available: Dict[str, int] = {}

        # Check Ascend NPU (importing torch_npu registers torch.npu).
        try:
            import torch_npu  # type: ignore[import]  # noqa: F401
            npu = getattr(torch, 'npu', None)
            if npu is not None and npu.is_available():
                available['npu'] = npu.device_count()
                logger.debug(f"Ascend NPU available: {available['npu']} devices")
        except ImportError:
            logger.debug("Ascend NPU library not available")

        # Check CUDA.
        if torch.cuda.is_available():
            available['cuda'] = torch.cuda.device_count()
            logger.debug(f"CUDA available: {available['cuda']} devices")

        # Check MPS (Apple Silicon); guarded for old torch builds without the attr.
        mps = getattr(getattr(torch, 'backends', None), 'mps', None)
        if mps is not None and hasattr(mps, 'is_available') and mps.is_available():
            available['mps'] = 1
            logger.debug("Apple MPS available")

        # CPU is always available.
        available['cpu'] = 1

        for device_type in ('npu', 'cuda', 'mps', 'cpu'):
            if device_type in available:
                return {'device': device_type, 'count': available[device_type]}
        return {'device': 'cpu', 'count': 1}  # Unreachable fallback

    def validate(self):
        """Comprehensive configuration validation.

        Returns:
            True when everything checks out.

        Raises:
            ValueError: listing every failed check at once.
        """
        errors = []
        # Check required paths (empty/None paths are intentionally skipped).
        if self.train_data_path and not os.path.exists(self.train_data_path):
            errors.append(f"Training data not found: {self.train_data_path}")
        if self.eval_data_path and not os.path.exists(self.eval_data_path):
            errors.append(f"Evaluation data not found: {self.eval_data_path}")
        # Validate numeric parameters.
        if self.learning_rate is None or self.learning_rate <= 0:
            errors.append("Learning rate must be positive and not None")
        if self.num_train_epochs is None or self.num_train_epochs <= 0:
            errors.append("Number of epochs must be positive and not None")
        if self.per_device_train_batch_size is None or self.per_device_train_batch_size <= 0:
            errors.append("Training batch size must be positive and not None")
        if self.gradient_accumulation_steps is None or self.gradient_accumulation_steps <= 0:
            errors.append("Gradient accumulation steps must be positive and not None")
        if self.warmup_ratio is None or self.warmup_ratio < 0 or self.warmup_ratio > 1:
            errors.append("Warmup ratio must be between 0 and 1 and not None")
        # Validate device.
        if self.device not in ['auto', 'cpu', 'cuda', 'npu', 'mps']:
            errors.append(f"Invalid device: {self.device}")
        # Check sequence lengths.
        if self.max_seq_length is None or self.max_seq_length <= 0:
            errors.append("Maximum sequence length must be positive and not None")
        if self.query_max_length is None or self.query_max_length <= 0:
            errors.append("Maximum query length must be positive and not None")
        if self.passage_max_length is None or self.passage_max_length <= 0:
            errors.append("Maximum passage length must be positive and not None")

        if errors:
            raise ValueError(f"Configuration validation failed:\n" + "\n".join(f"  - {error}" for error in errors))
        logger.info("Configuration validation passed")
        return True

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to a JSON-friendly dictionary.

        Scalars pass through; sequences/dicts are shallow-copied; anything
        else is stringified so json.dump never chokes.
        """
        config_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, (str, int, float, bool, type(None))):
                config_dict[key] = value
            elif isinstance(value, (list, tuple)):
                config_dict[key] = list(value)
            elif isinstance(value, dict):
                config_dict[key] = dict(value)
            else:
                config_dict[key] = str(value)
        return config_dict

    def save(self, save_path: str):
        """Save configuration to a JSON file.

        Raises:
            OSError / TypeError: re-raised after logging on failure.
        """
        try:
            parent = os.path.dirname(save_path)
            # A bare filename has no parent; os.makedirs("") would raise.
            if parent:
                os.makedirs(parent, exist_ok=True)

            import json
            with open(save_path, 'w', encoding='utf-8') as f:
                json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

            logger.info(f"Configuration saved to {save_path}")
        except Exception as e:
            logger.error(f"Failed to save configuration: {e}")
            raise

    @classmethod
    def _filter_init_kwargs(cls, config_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Keep only keys that are parameters of the generated __init__."""
        import inspect
        valid_keys = set(inspect.signature(cls).parameters)
        return {k: v for k, v in config_dict.items() if k in valid_keys}

    @classmethod
    def load(cls, load_path: str):
        """Load configuration from a JSON file, validate it, and return it.

        Unknown keys in the file are ignored rather than raising TypeError.
        """
        try:
            import json
            with open(load_path, 'r', encoding='utf-8') as f:
                config_dict = json.load(f)

            config = cls(**cls._filter_init_kwargs(config_dict))
            config.validate()  # Validate after loading

            logger.info(f"Configuration loaded from {load_path}")
            return config

        except Exception as e:
            logger.error(f"Failed to load configuration from {load_path}: {e}")
            raise

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]):
        """Create a config from a dictionary, ignoring unknown keys, and validate it."""
        try:
            config = cls(**cls._filter_init_kwargs(config_dict))
            config.validate()  # Validate after creation
            return config
        except Exception as e:
            logger.error(f"Failed to create configuration from dictionary: {e}")
            raise

    def update_from_args(self, args):
        """Update config from command line arguments (argparse namespace).

        Only attributes that already exist on the config and whose CLI value
        is not None are applied; the result is re-validated.
        """
        try:
            updated = []
            for key, value in vars(args).items():
                if hasattr(self, key) and value is not None:
                    old_value = getattr(self, key)
                    setattr(self, key, value)
                    updated.append(f"{key}: {old_value} -> {value}")

            if updated:
                logger.info("Updated configuration from command line arguments:")
                for change in updated:
                    logger.info(f"  {change}")

            # Validate after updates
            self.validate()

        except Exception as e:
            logger.error(f"Failed to update configuration from arguments: {e}")
            raise

    def get_effective_batch_size(self) -> int:
        """Calculate effective batch size considering gradient accumulation and devices."""
        return self.per_device_train_batch_size * self.gradient_accumulation_steps * max(1, self.n_gpu)

    def get_training_steps(self, dataset_size: int) -> int:
        """Calculate total training steps (floor of steps per epoch x epochs)."""
        steps_per_epoch = dataset_size // self.get_effective_batch_size()
        return steps_per_epoch * self.num_train_epochs

    def print_effective_config(self):
        """Print the effective value for each parameter (precedence in class docstring)."""
        print("Parameter | Value | Source")
        print("----------|-------|-------")
        # NOTE: loop variable renamed from `field` -- it shadowed dataclasses.field.
        for name in self.__dataclass_fields__:
            value = getattr(self, name)
            print(f"{name} | {value} | (see docstring for precedence)")

    def __str__(self) -> str:
        """Short, human-readable summary of the key settings."""
        return f"BaseConfig(device={self.device}, batch_size={self.per_device_train_batch_size}, lr={self.learning_rate})"

    def __repr__(self) -> str:
        return self.__str__()

    def get_device(self):
        """Return the preferred device string (cuda > npu > mps > cpu)."""
        if torch.cuda.is_available():
            return "cuda"
        npu = getattr(torch, "npu", None)
        if npu is not None and hasattr(npu, "is_available") and npu.is_available():
            return "npu"
        backends = getattr(torch, "backends", None)
        if backends is not None and hasattr(backends, "mps") and backends.mps.is_available():
            return "mps"
        return "cpu"

    def is_npu(self):
        """Return True if an NPU is available and is the preferred device."""
        npu = getattr(torch, "npu", None)
        if npu is not None and hasattr(npu, "is_available") and npu.is_available():
            # get_device() prefers CUDA over NPU, matching historical behavior.
            return self.get_device() == "npu"
        return False
|
||||
254
config/m3_config.py
Normal file
254
config/m3_config.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
BGE-M3 specific configuration
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict
|
||||
from .base_config import BaseConfig
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BGEM3Config(BaseConfig):
|
||||
"""
|
||||
Configuration for BGE-M3 model and trainer.
|
||||
"""
|
||||
use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss
|
||||
teacher_temperature: float = 1.0 # Temperature for teacher distillation
|
||||
kd_loss_weight: float = 1.0 # Weight for KD loss
|
||||
kd_loss_type: str = "mse" # Type of KD loss (e.g., mse, kl)
|
||||
hard_negative_margin: float = 0.0 # Margin for hard negative mining
|
||||
"""Configuration for BGE-M3 embedding model fine-tuning
|
||||
|
||||
Attributes:
|
||||
model_name_or_path: Path or identifier for the pretrained BGE-M3 model.
|
||||
use_dense: Enable dense embedding head.
|
||||
use_sparse: Enable sparse (SPLADE-style) embedding head.
|
||||
use_colbert: Enable ColBERT multi-vector head.
|
||||
unified_finetuning: Unified fine-tuning for all functionalities.
|
||||
sentence_pooling_method: Pooling method for sentence embeddings (e.g., 'cls').
|
||||
normalize_embeddings: Whether to normalize output embeddings.
|
||||
colbert_dim: Output dimension for ColBERT head (-1 uses model hidden size).
|
||||
sparse_top_k: Top-k for sparse representations.
|
||||
temperature: InfoNCE loss temperature.
|
||||
train_group_size: Number of passages per query in training.
|
||||
negatives_cross_device: Share negatives across devices/GPUs.
|
||||
use_hard_negatives: Enable hard negative mining.
|
||||
num_hard_negatives: Number of hard negatives to mine.
|
||||
hard_negative_mining_steps: Steps between hard negative mining refreshes.
|
||||
ance_negative_range: Range for ANCE negative mining.
|
||||
ance_margin: Margin for ANCE mining.
|
||||
use_self_distill: Enable self-knowledge distillation.
|
||||
self_distill_temperature: Temperature for self-distillation.
|
||||
self_distill_alpha: Alpha (weight) for self-distillation loss.
|
||||
use_query_instruction: Add instruction to queries.
|
||||
query_instruction_for_retrieval: Instruction for retrieval queries.
|
||||
query_instruction_for_reranking: Instruction for reranking queries.
|
||||
augment_queries: Enable query augmentation.
|
||||
augmentation_methods: List of augmentation methods to use.
|
||||
create_hard_negatives: Create hard negatives via augmentation.
|
||||
dense_weight: Loss weight for dense head.
|
||||
sparse_weight: Loss weight for sparse head.
|
||||
colbert_weight: Loss weight for ColBERT head.
|
||||
distillation_weight: Loss weight for distillation.
|
||||
eval_strategy: Evaluation strategy ('epoch', 'steps').
|
||||
metric_for_best_model: Metric to select best model.
|
||||
greater_is_better: Whether higher metric is better.
|
||||
use_gradient_checkpointing: Enable gradient checkpointing.
|
||||
max_passages_per_query: Max passages per query in memory.
|
||||
max_query_length: Max query length.
|
||||
max_passage_length: Max passage length.
|
||||
max_length: Max total input length (multi-granularity).
|
||||
"""
|
||||
|
||||
# Model specific
|
||||
model_name_or_path: str = "BAAI/bge-m3"
|
||||
|
||||
# Multi-functionality settings
|
||||
use_dense: bool = True
|
||||
use_sparse: bool = True
|
||||
use_colbert: bool = True
|
||||
unified_finetuning: bool = True
|
||||
|
||||
# Architecture settings
|
||||
sentence_pooling_method: str = "cls" # BGE-M3 uses CLS pooling
|
||||
normalize_embeddings: bool = True
|
||||
colbert_dim: int = -1 # -1 means use model hidden size
|
||||
sparse_top_k: int = 100 # Top-k for sparse representations
|
||||
|
||||
# Training specific
|
||||
temperature: float = 0.02 # InfoNCE temperature (from BGE-M3 paper)
|
||||
train_group_size: int = 8 # Number of passages per query
|
||||
negatives_cross_device: bool = True # Share negatives across GPUs
|
||||
|
||||
# Hard negative mining
|
||||
use_hard_negatives: bool = True
|
||||
num_hard_negatives: int = 15
|
||||
hard_negative_mining_steps: int = 1000
|
||||
ance_negative_range: tuple = (10, 100)
|
||||
ance_margin: float = 0.1
|
||||
|
||||
# Self-distillation
|
||||
use_self_distill: bool = False # Changed default to False - make KD opt-in
|
||||
self_distill_temperature: float = 4.0
|
||||
self_distill_alpha: float = 0.5
|
||||
|
||||
# Query instruction
|
||||
use_query_instruction: bool = False
|
||||
query_instruction_for_retrieval: str = "Represent this sentence for searching relevant passages:"
|
||||
query_instruction_for_reranking: str = ""
|
||||
|
||||
# Data augmentation
|
||||
augment_queries: bool = False
|
||||
augmentation_methods: List[str] = field(default_factory=lambda: ["paraphrase", "back_translation"])
|
||||
create_hard_negatives: bool = False
|
||||
|
||||
# Loss weights for multi-task learning
|
||||
dense_weight: float = 1.0
|
||||
sparse_weight: float = 0.3
|
||||
colbert_weight: float = 0.5
|
||||
distillation_weight: float = 1.0
|
||||
|
||||
# Evaluation
|
||||
eval_strategy: str = "epoch"
|
||||
metric_for_best_model: str = "eval_recall@10"
|
||||
greater_is_better: bool = True
|
||||
|
||||
# Memory optimization
|
||||
use_gradient_checkpointing: bool = True
|
||||
max_passages_per_query: int = 32
|
||||
|
||||
# Specific to M3's multi-granularity
|
||||
max_query_length: int = 64
|
||||
max_passage_length: int = 512
|
||||
max_length: int = 8192 # BGE-M3 supports up to 8192 tokens
|
||||
|
||||
def __post_init__(self):
|
||||
"""Load model-specific parameters from [m3] section in config.toml"""
|
||||
super().__post_init__()
|
||||
try:
|
||||
from utils.config_loader import get_config
|
||||
except ImportError:
|
||||
from bge_finetune.utils.config_loader import get_config # type: ignore[import]
|
||||
config = get_config()
|
||||
self._apply_m3_config(config)
|
||||
|
||||
def _apply_m3_config(self, config):
|
||||
"""
|
||||
Apply BGE-M3 specific configuration from [m3] section in config.toml.
|
||||
"""
|
||||
m3_config = config.get_section('m3')
|
||||
if not m3_config:
|
||||
logger.warning("No [m3] section found in config.toml")
|
||||
return
|
||||
|
||||
# Helper to convert tuple from config to list if needed
|
||||
def get_val(key, default):
|
||||
if key not in m3_config:
|
||||
return default
|
||||
v = m3_config[key]
|
||||
if isinstance(default, list) and isinstance(v, tuple):
|
||||
return list(v)
|
||||
return v
|
||||
|
||||
# Only update if key exists in config
|
||||
if 'use_dense' in m3_config:
|
||||
self.use_dense = m3_config['use_dense']
|
||||
if 'use_sparse' in m3_config:
|
||||
self.use_sparse = m3_config['use_sparse']
|
||||
if 'use_colbert' in m3_config:
|
||||
self.use_colbert = m3_config['use_colbert']
|
||||
if 'unified_finetuning' in m3_config:
|
||||
self.unified_finetuning = m3_config['unified_finetuning']
|
||||
if 'sentence_pooling_method' in m3_config:
|
||||
self.sentence_pooling_method = m3_config['sentence_pooling_method']
|
||||
if 'normalize_embeddings' in m3_config:
|
||||
self.normalize_embeddings = m3_config['normalize_embeddings']
|
||||
if 'colbert_dim' in m3_config:
|
||||
self.colbert_dim = m3_config['colbert_dim']
|
||||
if 'sparse_top_k' in m3_config:
|
||||
self.sparse_top_k = m3_config['sparse_top_k']
|
||||
if 'temperature' in m3_config:
|
||||
self.temperature = m3_config['temperature']
|
||||
if 'train_group_size' in m3_config:
|
||||
self.train_group_size = m3_config['train_group_size']
|
||||
if 'negatives_cross_device' in m3_config:
|
||||
self.negatives_cross_device = m3_config['negatives_cross_device']
|
||||
if 'use_hard_negatives' in m3_config:
|
||||
self.use_hard_negatives = m3_config['use_hard_negatives']
|
||||
if 'num_hard_negatives' in m3_config:
|
||||
self.num_hard_negatives = m3_config['num_hard_negatives']
|
||||
if 'hard_negative_mining_steps' in m3_config:
|
||||
self.hard_negative_mining_steps = m3_config['hard_negative_mining_steps']
|
||||
if 'ance_negative_range' in m3_config:
|
||||
self.ance_negative_range = tuple(get_val('ance_negative_range', self.ance_negative_range))
|
||||
if 'augmentation_methods' in m3_config:
|
||||
self.augmentation_methods = list(get_val('augmentation_methods', self.augmentation_methods))
|
||||
if 'ance_margin' in m3_config:
|
||||
self.ance_margin = m3_config['ance_margin']
|
||||
if 'use_self_distill' in m3_config:
|
||||
self.use_self_distill = m3_config['use_self_distill']
|
||||
if 'self_distill_temperature' in m3_config:
|
||||
self.self_distill_temperature = m3_config['self_distill_temperature']
|
||||
if 'self_distill_alpha' in m3_config:
|
||||
self.self_distill_alpha = m3_config['self_distill_alpha']
|
||||
if 'use_query_instruction' in m3_config:
|
||||
self.use_query_instruction = m3_config['use_query_instruction']
|
||||
if 'query_instruction_for_retrieval' in m3_config:
|
||||
self.query_instruction_for_retrieval = m3_config['query_instruction_for_retrieval']
|
||||
if 'query_instruction_for_reranking' in m3_config:
|
||||
self.query_instruction_for_reranking = m3_config['query_instruction_for_reranking']
|
||||
if 'augment_queries' in m3_config:
|
||||
self.augment_queries = m3_config['augment_queries']
|
||||
if 'create_hard_negatives' in m3_config:
|
||||
self.create_hard_negatives = m3_config['create_hard_negatives']
|
||||
if 'dense_weight' in m3_config:
|
||||
self.dense_weight = m3_config['dense_weight']
|
||||
if 'sparse_weight' in m3_config:
|
||||
self.sparse_weight = m3_config['sparse_weight']
|
||||
if 'colbert_weight' in m3_config:
|
||||
self.colbert_weight = m3_config['colbert_weight']
|
||||
if 'distillation_weight' in m3_config:
|
||||
self.distillation_weight = m3_config['distillation_weight']
|
||||
if 'metric_for_best_model' in m3_config:
|
||||
self.metric_for_best_model = m3_config['metric_for_best_model']
|
||||
if 'greater_is_better' in m3_config:
|
||||
self.greater_is_better = m3_config['greater_is_better']
|
||||
if 'use_gradient_checkpointing' in m3_config:
|
||||
self.use_gradient_checkpointing = m3_config['use_gradient_checkpointing']
|
||||
if 'max_passages_per_query' in m3_config:
|
||||
self.max_passages_per_query = m3_config['max_passages_per_query']
|
||||
if 'max_query_length' in m3_config:
|
||||
self.max_query_length = m3_config['max_query_length']
|
||||
if 'max_passage_length' in m3_config:
|
||||
self.max_passage_length = m3_config['max_passage_length']
|
||||
if 'max_length' in m3_config:
|
||||
self.max_length = m3_config['max_length']
|
||||
|
||||
# Log parameter overrides
|
||||
logger.info("Loaded BGE-M3 config from [m3] section in config.toml")
|
||||
|
||||
# Check for required parameters
|
||||
if self.temperature is None or self.temperature <= 0:
|
||||
raise ValueError("Temperature must be positive and not None")
|
||||
if self.train_group_size is None or self.train_group_size < 2:
|
||||
raise ValueError("train_group_size must be at least 2 and not None")
|
||||
|
||||
def validate(self):
    """Check the configuration for invalid settings.

    Raises ValueError for hard errors, logs a warning for suspicious but
    allowed combinations, and returns True on success.
    """
    heads_enabled = (self.use_dense, self.use_sparse, self.use_colbert)

    # At least one retrieval head has to be active.
    if not any(heads_enabled):
        raise ValueError("At least one of use_dense, use_sparse, or use_colbert must be True")

    # Unified fine-tuning normally trains all three heads together.
    if self.unified_finetuning and not all(heads_enabled):
        logger.warning("unified_finetuning is True but not all functionalities are enabled")

    if self.temperature <= 0:
        raise ValueError("Temperature must be positive")

    if self.train_group_size < 2:
        raise ValueError("train_group_size must be at least 2")

    # BGE-M3's context window tops out at 8192 tokens.
    if self.max_length > 8192:
        logger.warning("max_length > 8192 may not be supported by BGE-M3")

    return True
|
||||
264
config/reranker_config.py
Normal file
264
config/reranker_config.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
BGE-Reranker specific configuration
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict
|
||||
from .base_config import BaseConfig
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BGERerankerConfig(BaseConfig):
    """Configuration for BGE-Reranker cross-encoder fine-tuning.

    Attributes:
        num_labels: Number of output labels (1 = single relevance score).
        listwise_temperature: Temperature for listwise loss.
        kd_temperature: Temperature for KD loss.
        kd_loss_weight: Weight for KD loss.
        denoise_confidence_threshold: Threshold for denoising.
        model_name_or_path: Path or identifier for the pretrained BGE-Reranker model.
        add_pooling_layer: Whether to add a pooling layer (not used for cross-encoder).
        use_fp16: Use mixed precision (float16) training.
        train_group_size: Number of passages per query for listwise training.
        loss_type: Loss function type; must be one of ``_VALID_LOSS_TYPES``.
        margin: Margin for pairwise loss.
        temperature: Temperature for listwise CE loss.
        use_knowledge_distillation: Enable knowledge distillation from teacher model.
        teacher_model_name_or_path: Path to teacher model for distillation.
        distillation_temperature: Temperature for distillation.
        distillation_alpha: Alpha (weight) for distillation loss.
        use_hard_negatives: Enable hard negative mining.
        hard_negative_ratio: Ratio of hard to random negatives.
        use_denoising: Enable denoising strategy.
        noise_probability: Probability of noise for denoising.
        max_pairs_per_device: Max query-passage pairs per device.
        accumulate_gradients_from_all_devices: Accumulate gradients from all devices.
        max_seq_length: Max combined query+passage length.
        query_max_length: Max query length (if separate encoding).
        passage_max_length: Max passage length (if separate encoding).
        shuffle_passages: Shuffle passage order during training.
        position_bias_cutoff: Top-k passages to consider for position bias.
        use_gradient_caching: Enable gradient caching for efficiency.
        gradient_cache_batch_size: Batch size for gradient caching.
        eval_group_size: Number of candidates during evaluation.
        metric_for_best_model: Metric to select best model.
        eval_normalize_scores: Whether to normalize scores during evaluation.
        listwise_weight: Loss weight for listwise loss.
        pairwise_weight: Loss weight for pairwise loss.
        filter_empty_passages: Filter out empty passages.
        min_passage_length: Minimum passage length to keep.
        max_passage_length_ratio: Max ratio between query and passage length.
        gradient_checkpointing: Enable gradient checkpointing.
        recompute_scores: Recompute scores instead of storing all.
        use_dynamic_teacher: Enable dynamic teacher for RocketQAv2-style training.
        teacher_update_steps: Steps between teacher updates.
        teacher_momentum: Momentum for teacher updates.
    """
    num_labels: int = 1  # Number of output labels
    listwise_temperature: float = 1.0  # Temperature for listwise loss
    kd_temperature: float = 1.0  # Temperature for KD loss
    kd_loss_weight: float = 1.0  # Weight for KD loss
    denoise_confidence_threshold: float = 0.0  # Threshold for denoising

    # Model specific
    model_name_or_path: str = "BAAI/bge-reranker-base"

    # Architecture settings
    add_pooling_layer: bool = False  # Cross-encoder doesn't use pooling
    use_fp16: bool = True  # Use mixed precision by default

    # Training specific
    train_group_size: int = 16  # Number of passages per query for listwise training
    loss_type: str = "listwise_ce"  # See _VALID_LOSS_TYPES below
    margin: float = 1.0  # For pairwise margin loss
    temperature: float = 1.0  # For listwise CE loss

    # Knowledge distillation
    use_knowledge_distillation: bool = False
    teacher_model_name_or_path: Optional[str] = None
    distillation_temperature: float = 3.0
    distillation_alpha: float = 0.7

    # Hard negative strategies
    use_hard_negatives: bool = True
    hard_negative_ratio: float = 0.8  # Ratio of hard to random negatives
    use_denoising: bool = False  # Denoising strategy from paper
    noise_probability: float = 0.1

    # Listwise training specific
    max_pairs_per_device: int = 128  # Maximum query-passage pairs per device
    accumulate_gradients_from_all_devices: bool = True

    # Cross-encoder specific settings
    max_seq_length: int = 512  # Maximum combined query+passage length
    query_max_length: int = 64  # For separate encoding (if needed)
    passage_max_length: int = 512  # For separate encoding (if needed)

    # Position bias mitigation
    shuffle_passages: bool = True  # Shuffle passage order during training
    position_bias_cutoff: int = 10  # Consider position bias for top-k passages

    # Advanced training strategies
    use_gradient_caching: bool = True  # Cache gradients for efficient training
    gradient_cache_batch_size: int = 32

    # Evaluation specific
    eval_group_size: int = 100  # Number of candidates during evaluation
    metric_for_best_model: str = "eval_mrr@10"
    eval_normalize_scores: bool = True

    # Loss combination weights (for combined loss)
    listwise_weight: float = 0.7
    pairwise_weight: float = 0.3

    # Data filtering
    filter_empty_passages: bool = True
    min_passage_length: int = 10
    max_passage_length_ratio: float = 50.0  # Max ratio between query and passage length

    # Memory optimization for large group sizes
    gradient_checkpointing: bool = True
    recompute_scores: bool = False  # Recompute instead of storing all scores

    # Dynamic teacher (for RocketQAv2-style training)
    use_dynamic_teacher: bool = False
    teacher_update_steps: int = 1000
    teacher_momentum: float = 0.999

    # Single source of truth for loss_type validation. Previously
    # __post_init__ accepted five values while validate() accepted only
    # three, so e.g. loss_type="cross_entropy" passed construction but
    # failed validate(); both paths now check the same set.
    # (Unannotated class attribute: not a dataclass field.)
    _VALID_LOSS_TYPES = (
        "listwise_ce", "pairwise_margin", "combined",
        "cross_entropy", "binary_cross_entropy",
    )

    # [reranker] keys that are copied verbatim onto attributes of the
    # same name when present in config.toml (same set as before).
    _CONFIG_KEYS = (
        'add_pooling_layer', 'use_fp16', 'train_group_size', 'loss_type',
        'margin', 'temperature', 'use_knowledge_distillation',
        'teacher_model_name_or_path', 'distillation_temperature',
        'distillation_alpha', 'use_hard_negatives', 'hard_negative_ratio',
        'use_denoising', 'noise_probability', 'max_pairs_per_device',
        'accumulate_gradients_from_all_devices', 'max_seq_length',
        'query_max_length', 'passage_max_length', 'shuffle_passages',
        'position_bias_cutoff', 'use_gradient_caching',
        'gradient_cache_batch_size', 'eval_group_size',
        'metric_for_best_model', 'eval_normalize_scores',
        'listwise_weight', 'pairwise_weight', 'filter_empty_passages',
        'min_passage_length', 'max_passage_length_ratio',
        'gradient_checkpointing', 'recompute_scores',
        'use_dynamic_teacher', 'teacher_update_steps', 'teacher_momentum',
    )

    def __post_init__(self):
        """Load model-specific parameters from [reranker] section in config.toml."""
        super().__post_init__()
        # Import here to tolerate both in-repo and installed-package layouts.
        try:
            from utils.config_loader import get_config
        except ImportError:
            from bge_finetune.utils.config_loader import get_config  # type: ignore[import]
        config = get_config()
        self._apply_reranker_config(config)

    def _apply_reranker_config(self, config):
        """Apply BGE-Reranker specific configuration from the [reranker]
        section in config.toml, then enforce required-parameter invariants.

        Raises:
            ValueError: if loss_type, train_group_size, or the
                knowledge-distillation settings are invalid.
        """
        reranker_config = config.get_section('reranker')
        if not reranker_config:
            logger.warning("No [reranker] section found in config.toml")
            return

        # Only update attributes whose key actually exists in the section.
        for key in self._CONFIG_KEYS:
            if key in reranker_config:
                setattr(self, key, reranker_config[key])

        # Log parameter overrides
        logger.info("Loaded BGE-Reranker config from [reranker] section in config.toml")

        # Check for required parameters (mirrored by validate()).
        if self.loss_type not in self._VALID_LOSS_TYPES:
            raise ValueError(f"loss_type must be one of {list(self._VALID_LOSS_TYPES)}")
        if self.train_group_size is None or self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2 and not None for reranking")
        if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
            raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")

    def validate(self):
        """Validate configuration settings.

        Raises ValueError for hard errors, logs warnings for dubious but
        allowed settings, and returns True on success.
        """
        if self.loss_type not in self._VALID_LOSS_TYPES:
            raise ValueError(f"loss_type must be one of {list(self._VALID_LOSS_TYPES)}")

        if self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2 for reranking")

        if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
            raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")

        if self.hard_negative_ratio < 0 or self.hard_negative_ratio > 1:
            raise ValueError("hard_negative_ratio must be between 0 and 1")

        if self.max_seq_length > 512 and "base" in self.model_name_or_path:
            logger.warning("max_seq_length > 512 may not be optimal for base models")

        # per_device_train_batch_size is inherited from BaseConfig.
        if self.gradient_cache_batch_size > self.per_device_train_batch_size:
            logger.warning("gradient_cache_batch_size > per_device_train_batch_size may not improve efficiency")

        return True
|
||||
Reference in New Issue
Block a user