"""
|
|
BGE-Reranker specific configuration
|
|
"""
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional, List, Dict
|
|
from .base_config import BaseConfig
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
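
# Illustrative [reranker] section for config.toml (example values and model names
# only; any key present there overrides the matching field default in __post_init__):
#
#   [reranker]
#   train_group_size = 16
#   loss_type = "listwise_ce"
#   use_knowledge_distillation = true
#   teacher_model_name_or_path = "BAAI/bge-reranker-large"
#   max_seq_length = 512
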
@dataclass
class BGERerankerConfig(BaseConfig):
    """Configuration for BGE-Reranker cross-encoder fine-tuning.

    Values in the [reranker] section of config.toml override these defaults
    (see __post_init__).

    Attributes:
        num_labels: Number of output labels for the scoring head.
        listwise_temperature: Temperature for the listwise loss.
        kd_temperature: Temperature for the knowledge-distillation (KD) loss.
        kd_loss_weight: Weight for the KD loss.
        denoise_confidence_threshold: Confidence threshold for denoising.
        model_name_or_path: Path or identifier for the pretrained BGE-Reranker model.
        add_pooling_layer: Whether to add a pooling layer (not used for cross-encoder).
        use_fp16: Use mixed precision (float16) training.
        train_group_size: Number of passages per query for listwise training.
        loss_type: Loss function type ('listwise_ce', 'pairwise_margin', 'combined').
        margin: Margin for pairwise loss.
        temperature: Temperature for listwise CE loss.
        use_knowledge_distillation: Enable knowledge distillation from a teacher model.
        teacher_model_name_or_path: Path to the teacher model for distillation.
        distillation_temperature: Temperature for distillation.
        distillation_alpha: Alpha (weight) for the distillation loss.
        use_hard_negatives: Enable hard negative mining.
        hard_negative_ratio: Ratio of hard to random negatives.
        use_denoising: Enable the denoising strategy.
        noise_probability: Noise probability for denoising.
        max_pairs_per_device: Maximum query-passage pairs per device.
        accumulate_gradients_from_all_devices: Accumulate gradients from all devices.
        max_seq_length: Maximum combined query+passage length.
        query_max_length: Maximum query length (if encoded separately).
        passage_max_length: Maximum passage length (if encoded separately).
        shuffle_passages: Shuffle passage order during training.
        position_bias_cutoff: Top-k passages to consider for position bias.
        use_gradient_caching: Enable gradient caching for efficiency.
        gradient_cache_batch_size: Batch size for gradient caching.
        eval_group_size: Number of candidates during evaluation.
        metric_for_best_model: Metric used to select the best model.
        eval_normalize_scores: Whether to normalize scores during evaluation.
        listwise_weight: Loss weight for the listwise loss.
        pairwise_weight: Loss weight for the pairwise loss.
        filter_empty_passages: Filter out empty passages.
        min_passage_length: Minimum passage length to keep.
        max_passage_length_ratio: Maximum ratio between query and passage length.
        gradient_checkpointing: Enable gradient checkpointing.
        recompute_scores: Recompute scores instead of storing all of them.
        use_dynamic_teacher: Enable a dynamic teacher for RocketQAv2-style training.
        teacher_update_steps: Steps between teacher updates.
        teacher_momentum: Momentum for teacher updates.
    """

    num_labels: int = 1  # Number of output labels
    listwise_temperature: float = 1.0  # Temperature for listwise loss
    kd_temperature: float = 1.0  # Temperature for KD loss
    kd_loss_weight: float = 1.0  # Weight for KD loss
    denoise_confidence_threshold: float = 0.0  # Threshold for denoising

    # Model specific
    model_name_or_path: str = "BAAI/bge-reranker-base"

    # Architecture settings
    add_pooling_layer: bool = False  # Cross-encoder doesn't use pooling
    use_fp16: bool = True  # Use mixed precision by default

    # Training specific
    train_group_size: int = 16  # Number of passages per query for listwise training
    loss_type: str = "listwise_ce"  # "listwise_ce", "pairwise_margin", "combined"
    margin: float = 1.0  # For pairwise margin loss
    temperature: float = 1.0  # For listwise CE loss

    # Knowledge distillation
    use_knowledge_distillation: bool = False
    teacher_model_name_or_path: Optional[str] = None
    distillation_temperature: float = 3.0
    distillation_alpha: float = 0.7

    # Hard negative strategies
    use_hard_negatives: bool = True
    hard_negative_ratio: float = 0.8  # Ratio of hard to random negatives
    use_denoising: bool = False  # Denoising strategy from the paper
    noise_probability: float = 0.1

    # Listwise training specific
    max_pairs_per_device: int = 128  # Maximum query-passage pairs per device
    accumulate_gradients_from_all_devices: bool = True

    # Cross-encoder specific settings
    max_seq_length: int = 512  # Maximum combined query+passage length
    query_max_length: int = 64  # For separate encoding (if needed)
    passage_max_length: int = 512  # For separate encoding (if needed)

    # Position bias mitigation
    shuffle_passages: bool = True  # Shuffle passage order during training
    position_bias_cutoff: int = 10  # Consider position bias for top-k passages

    # Advanced training strategies
    use_gradient_caching: bool = True  # Cache gradients for efficient training
    gradient_cache_batch_size: int = 32

    # Evaluation specific
    eval_group_size: int = 100  # Number of candidates during evaluation
    metric_for_best_model: str = "eval_mrr@10"
    eval_normalize_scores: bool = True

    # Loss combination weights (for combined loss); see the sketch after the field definitions
    listwise_weight: float = 0.7
    pairwise_weight: float = 0.3

    # Data filtering
    filter_empty_passages: bool = True
    min_passage_length: int = 10
    max_passage_length_ratio: float = 50.0  # Max ratio between query and passage length

    # Memory optimization for large group sizes
    gradient_checkpointing: bool = True
    recompute_scores: bool = False  # Recompute instead of storing all scores

    # Dynamic teacher (for RocketQAv2-style training)
    use_dynamic_teacher: bool = False
    teacher_update_steps: int = 1000
    teacher_momentum: float = 0.999
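
    # How the trainer is expected to combine losses when loss_type == "combined"
    # (illustrative sketch only; the actual formula lives in the trainer, not here):
    #
    #   loss = listwise_weight * listwise_ce(scores / temperature, labels)
    #        + pairwise_weight * pairwise_margin(scores, labels, margin=margin)
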
    def __post_init__(self):
        """Load model-specific parameters from the [reranker] section of config.toml."""
        super().__post_init__()
        try:
            from utils.config_loader import get_config
        except ImportError:
            from bge_finetune.utils.config_loader import get_config  # type: ignore[import]
        config = get_config()
        self._apply_reranker_config(config)

    def _apply_reranker_config(self, config):
        """Apply BGE-Reranker specific settings from the [reranker] section of config.toml."""
        reranker_config = config.get_section('reranker')
        if not reranker_config:
            logger.warning("No [reranker] section found in config.toml")
            return

        # Helper to reconcile container types between config.toml values and field defaults
        def get_val(key, default):
            if key not in reranker_config:
                return default
            v = reranker_config[key]
            if isinstance(default, tuple) and isinstance(v, list):
                return tuple(v)
            if isinstance(default, list) and isinstance(v, tuple):
                return list(v)
            return v

        # Only update attributes whose keys are present in the config section
        overridable_keys = (
            'add_pooling_layer', 'use_fp16', 'train_group_size', 'loss_type',
            'margin', 'temperature', 'use_knowledge_distillation',
            'teacher_model_name_or_path', 'distillation_temperature',
            'distillation_alpha', 'use_hard_negatives', 'hard_negative_ratio',
            'use_denoising', 'noise_probability', 'max_pairs_per_device',
            'accumulate_gradients_from_all_devices', 'max_seq_length',
            'query_max_length', 'passage_max_length', 'shuffle_passages',
            'position_bias_cutoff', 'use_gradient_caching',
            'gradient_cache_batch_size', 'eval_group_size',
            'metric_for_best_model', 'eval_normalize_scores', 'listwise_weight',
            'pairwise_weight', 'filter_empty_passages', 'min_passage_length',
            'max_passage_length_ratio', 'gradient_checkpointing',
            'recompute_scores', 'use_dynamic_teacher', 'teacher_update_steps',
            'teacher_momentum',
        )
        for key in overridable_keys:
            if key in reranker_config:
                setattr(self, key, get_val(key, getattr(self, key)))

        # Log parameter overrides
        logger.info("Loaded BGE-Reranker config from [reranker] section in config.toml")

        # Check for required parameters
        valid_loss_types = ["listwise_ce", "pairwise_margin", "combined", "cross_entropy", "binary_cross_entropy"]
        if self.loss_type not in valid_loss_types:
            raise ValueError(f"loss_type must be one of {valid_loss_types}")
        if self.train_group_size is None or self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2 and not None for reranking")
        if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
            raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")

    def validate(self):
        """Validate configuration settings."""
        valid_loss_types = ["listwise_ce", "pairwise_margin", "combined"]
        if self.loss_type not in valid_loss_types:
            raise ValueError(f"loss_type must be one of {valid_loss_types}")

        if self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2 for reranking")

        if self.use_knowledge_distillation and not self.teacher_model_name_or_path:
            raise ValueError("teacher_model_name_or_path required when use_knowledge_distillation=True")

        if self.hard_negative_ratio < 0 or self.hard_negative_ratio > 1:
            raise ValueError("hard_negative_ratio must be between 0 and 1")

        if self.max_seq_length > 512 and "base" in self.model_name_or_path:
            logger.warning("max_seq_length > 512 may not be optimal for base models")

        if self.gradient_cache_batch_size > self.per_device_train_batch_size:
            logger.warning("gradient_cache_batch_size > per_device_train_batch_size may not improve efficiency")

        return True
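
# Minimal usage sketch (illustrative, kept as a comment; assumes a config.toml with a
# [reranker] section is discoverable by utils.config_loader.get_config, and that
# BaseConfig defines per_device_train_batch_size, which validate() reads):
#
#   config = BGERerankerConfig(train_group_size=16, loss_type="combined")
#   config.validate()
#   print(config.listwise_weight, config.pairwise_weight)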