""" BGE-Reranker specific configuration """ from dataclasses import dataclass, field from typing import Optional, List, Dict from .base_config import BaseConfig import logging logger = logging.getLogger(__name__) @dataclass class BGERerankerConfig(BaseConfig): """ Configuration for BGE-Reranker model and trainer. """ num_labels: int = 1 # Number of output labels listwise_temperature: float = 1.0 # Temperature for listwise loss kd_temperature: float = 1.0 # Temperature for KD loss kd_loss_weight: float = 1.0 # Weight for KD loss denoise_confidence_threshold: float = 0.0 # Threshold for denoising """Configuration for BGE-Reranker cross-encoder fine-tuning Attributes: model_name_or_path: Path or identifier for the pretrained BGE-Reranker model. add_pooling_layer: Whether to add a pooling layer (not used for cross-encoder). use_fp16: Use mixed precision (float16) training. train_group_size: Number of passages per query for listwise training. loss_type: Loss function type ('listwise_ce', 'pairwise_margin', 'combined'). margin: Margin for pairwise loss. temperature: Temperature for listwise CE loss. use_knowledge_distillation: Enable knowledge distillation from teacher model. teacher_model_name_or_path: Path to teacher model for distillation. distillation_temperature: Temperature for distillation. distillation_alpha: Alpha (weight) for distillation loss. use_hard_negatives: Enable hard negative mining. hard_negative_ratio: Ratio of hard to random negatives. use_denoising: Enable denoising strategy. noise_probability: Probability of noise for denoising. max_pairs_per_device: Max query-passage pairs per device. accumulate_gradients_from_all_devices: Accumulate gradients from all devices. max_seq_length: Max combined query+passage length. query_max_length: Max query length (if separate encoding). passage_max_length: Max passage length (if separate encoding). shuffle_passages: Shuffle passage order during training. position_bias_cutoff: Top-k passages to consider for position bias. use_gradient_caching: Enable gradient caching for efficiency. gradient_cache_batch_size: Batch size for gradient caching. eval_group_size: Number of candidates during evaluation. metric_for_best_model: Metric to select best model. eval_normalize_scores: Whether to normalize scores during evaluation. listwise_weight: Loss weight for listwise loss. pairwise_weight: Loss weight for pairwise loss. filter_empty_passages: Filter out empty passages. min_passage_length: Minimum passage length to keep. max_passage_length_ratio: Max ratio between query and passage length. gradient_checkpointing: Enable gradient checkpointing. recompute_scores: Recompute scores instead of storing all. use_dynamic_teacher: Enable dynamic teacher for RocketQAv2-style training. teacher_update_steps: Steps between teacher updates. teacher_momentum: Momentum for teacher updates. 
""" # Model specific model_name_or_path: str = "BAAI/bge-reranker-base" # Architecture settings add_pooling_layer: bool = False # Cross-encoder doesn't use pooling use_fp16: bool = True # Use mixed precision by default # Training specific train_group_size: int = 16 # Number of passages per query for listwise training loss_type: str = "listwise_ce" # "listwise_ce", "pairwise_margin", "combined" margin: float = 1.0 # For pairwise margin loss temperature: float = 1.0 # For listwise CE loss # Knowledge distillation use_knowledge_distillation: bool = False teacher_model_name_or_path: Optional[str] = None distillation_temperature: float = 3.0 distillation_alpha: float = 0.7 # Hard negative strategies use_hard_negatives: bool = True hard_negative_ratio: float = 0.8 # Ratio of hard to random negatives use_denoising: bool = False # Denoising strategy from paper noise_probability: float = 0.1 # Listwise training specific max_pairs_per_device: int = 128 # Maximum query-passage pairs per device accumulate_gradients_from_all_devices: bool = True # Cross-encoder specific settings max_seq_length: int = 512 # Maximum combined query+passage length query_max_length: int = 64 # For separate encoding (if needed) passage_max_length: int = 512 # For separate encoding (if needed) # Position bias mitigation shuffle_passages: bool = True # Shuffle passage order during training position_bias_cutoff: int = 10 # Consider position bias for top-k passages # Advanced training strategies use_gradient_caching: bool = True # Cache gradients for efficient training gradient_cache_batch_size: int = 32 # Evaluation specific eval_group_size: int = 100 # Number of candidates during evaluation metric_for_best_model: str = "eval_mrr@10" eval_normalize_scores: bool = True # Loss combination weights (for combined loss) listwise_weight: float = 0.7 pairwise_weight: float = 0.3 # Data filtering filter_empty_passages: bool = True min_passage_length: int = 10 max_passage_length_ratio: float = 50.0 # Max ratio between query and passage length # Memory optimization for large group sizes gradient_checkpointing: bool = True recompute_scores: bool = False # Recompute instead of storing all scores # Dynamic teacher (for RocketQAv2-style training) use_dynamic_teacher: bool = False teacher_update_steps: int = 1000 teacher_momentum: float = 0.999 def __post_init__(self): """Load model-specific parameters from [reranker] section in config.toml""" super().__post_init__() try: from utils.config_loader import get_config except ImportError: from bge_finetune.utils.config_loader import get_config # type: ignore[import] config = get_config() self._apply_reranker_config(config) def _apply_reranker_config(self, config): """ Apply BGE-Reranker specific configuration from [reranker] section in config.toml. 
""" reranker_config = config.get_section('reranker') if not reranker_config: logger.warning("No [reranker] section found in config.toml") return # Helper to convert tuple from config to list if needed def get_val(key, default): if key not in reranker_config: return default v = reranker_config[key] if isinstance(default, tuple) and isinstance(v, list): return tuple(v) if isinstance(default, list) and isinstance(v, tuple): return list(v) return v # Only update if key exists in config if 'add_pooling_layer' in reranker_config: self.add_pooling_layer = reranker_config['add_pooling_layer'] if 'use_fp16' in reranker_config: self.use_fp16 = reranker_config['use_fp16'] if 'train_group_size' in reranker_config: self.train_group_size = reranker_config['train_group_size'] if 'loss_type' in reranker_config: self.loss_type = reranker_config['loss_type'] if 'margin' in reranker_config: self.margin = reranker_config['margin'] if 'temperature' in reranker_config: self.temperature = reranker_config['temperature'] if 'use_knowledge_distillation' in reranker_config: self.use_knowledge_distillation = reranker_config['use_knowledge_distillation'] if 'teacher_model_name_or_path' in reranker_config: self.teacher_model_name_or_path = reranker_config['teacher_model_name_or_path'] if 'distillation_temperature' in reranker_config: self.distillation_temperature = reranker_config['distillation_temperature'] if 'distillation_alpha' in reranker_config: self.distillation_alpha = reranker_config['distillation_alpha'] if 'use_hard_negatives' in reranker_config: self.use_hard_negatives = reranker_config['use_hard_negatives'] if 'hard_negative_ratio' in reranker_config: self.hard_negative_ratio = reranker_config['hard_negative_ratio'] if 'use_denoising' in reranker_config: self.use_denoising = reranker_config['use_denoising'] if 'noise_probability' in reranker_config: self.noise_probability = reranker_config['noise_probability'] if 'max_pairs_per_device' in reranker_config: self.max_pairs_per_device = reranker_config['max_pairs_per_device'] if 'accumulate_gradients_from_all_devices' in reranker_config: self.accumulate_gradients_from_all_devices = reranker_config['accumulate_gradients_from_all_devices'] if 'max_seq_length' in reranker_config: self.max_seq_length = reranker_config['max_seq_length'] if 'query_max_length' in reranker_config: self.query_max_length = reranker_config['query_max_length'] if 'passage_max_length' in reranker_config: self.passage_max_length = reranker_config['passage_max_length'] if 'shuffle_passages' in reranker_config: self.shuffle_passages = reranker_config['shuffle_passages'] if 'position_bias_cutoff' in reranker_config: self.position_bias_cutoff = reranker_config['position_bias_cutoff'] if 'use_gradient_caching' in reranker_config: self.use_gradient_caching = reranker_config['use_gradient_caching'] if 'gradient_cache_batch_size' in reranker_config: self.gradient_cache_batch_size = reranker_config['gradient_cache_batch_size'] if 'eval_group_size' in reranker_config: self.eval_group_size = reranker_config['eval_group_size'] if 'metric_for_best_model' in reranker_config: self.metric_for_best_model = reranker_config['metric_for_best_model'] if 'eval_normalize_scores' in reranker_config: self.eval_normalize_scores = reranker_config['eval_normalize_scores'] if 'listwise_weight' in reranker_config: self.listwise_weight = reranker_config['listwise_weight'] if 'pairwise_weight' in reranker_config: self.pairwise_weight = reranker_config['pairwise_weight'] if 'filter_empty_passages' in reranker_config: 
        # Log parameter overrides
        logger.info("Loaded BGE-Reranker config from [reranker] section in config.toml")

        # Check required parameters.
        # Note: this check accepts two extra loss types ('cross_entropy',
        # 'binary_cross_entropy') that validate() below rejects.
        if self.loss_type not in ["listwise_ce", "pairwise_margin", "combined", "cross_entropy", "binary_cross_entropy"]:
            raise ValueError(
                "loss_type must be one of ['listwise_ce', 'pairwise_margin', 'combined', "
                "'cross_entropy', 'binary_cross_entropy']"
            )
        if self.train_group_size is None or self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2 and not None for reranking")
        if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
            raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")

    def validate(self):
        """Validate configuration settings."""
        valid_loss_types = ["listwise_ce", "pairwise_margin", "combined"]
        if self.loss_type not in valid_loss_types:
            raise ValueError(f"loss_type must be one of {valid_loss_types}")

        if self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2 for reranking")

        if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
            raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")

        if self.hard_negative_ratio < 0 or self.hard_negative_ratio > 1:
            raise ValueError("hard_negative_ratio must be between 0 and 1")

        if self.max_seq_length > 512 and "base" in self.model_name_or_path:
            logger.warning("max_seq_length > 512 may not be optimal for base models")

        if self.gradient_cache_batch_size > self.per_device_train_batch_size:
            logger.warning("gradient_cache_batch_size > per_device_train_batch_size may not improve efficiency")

        return True
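

# ---------------------------------------------------------------------------
# Illustrative [reranker] section for config.toml (referenced from the class
# docstring above). The key names match the overrides applied in
# _apply_reranker_config(); the values shown are simply the dataclass defaults,
# not a recommended configuration, and the teacher path is a placeholder.
#
#   [reranker]
#   train_group_size = 16
#   loss_type = "listwise_ce"      # or "pairwise_margin" / "combined"
#   margin = 1.0
#   temperature = 1.0
#   use_knowledge_distillation = false
#   # teacher_model_name_or_path = "path/to/teacher"  # required if distillation is on
#   use_hard_negatives = true
#   hard_negative_ratio = 0.8
#   max_seq_length = 512
#   eval_group_size = 100
#   metric_for_best_model = "eval_mrr@10"
#   listwise_weight = 0.7
#   pairwise_weight = 0.3
# ---------------------------------------------------------------------------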
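

# Minimal usage sketch (not part of the library surface): construct the config,
# let __post_init__ pick up any [reranker] overrides from config.toml via the
# project's config loader, then validate. This assumes BaseConfig defines the
# generic training fields referenced above (e.g. per_device_train_batch_size)
# and that a config.toml is discoverable by utils.config_loader.get_config().
if __name__ == "__main__":
    cfg = BGERerankerConfig(
        model_name_or_path="BAAI/bge-reranker-base",
        loss_type="combined",
        train_group_size=8,
    )
    cfg.validate()  # raises ValueError on inconsistent settings
    print(
        f"loss_type={cfg.loss_type} train_group_size={cfg.train_group_size} "
        f"listwise_weight={cfg.listwise_weight} pairwise_weight={cfg.pairwise_weight}"
    )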