""" BGE-M3 specific configuration """ from dataclasses import dataclass, field from typing import Optional, List, Dict from .base_config import BaseConfig import logging logger = logging.getLogger(__name__) @dataclass class BGEM3Config(BaseConfig): """ Configuration for BGE-M3 model and trainer. """ use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss teacher_temperature: float = 1.0 # Temperature for teacher distillation kd_loss_weight: float = 1.0 # Weight for KD loss kd_loss_type: str = "mse" # Type of KD loss (e.g., mse, kl) hard_negative_margin: float = 0.0 # Margin for hard negative mining """Configuration for BGE-M3 embedding model fine-tuning Attributes: model_name_or_path: Path or identifier for the pretrained BGE-M3 model. use_dense: Enable dense embedding head. use_sparse: Enable sparse (SPLADE-style) embedding head. use_colbert: Enable ColBERT multi-vector head. unified_finetuning: Unified fine-tuning for all functionalities. sentence_pooling_method: Pooling method for sentence embeddings (e.g., 'cls'). normalize_embeddings: Whether to normalize output embeddings. colbert_dim: Output dimension for ColBERT head (-1 uses model hidden size). sparse_top_k: Top-k for sparse representations. temperature: InfoNCE loss temperature. train_group_size: Number of passages per query in training. negatives_cross_device: Share negatives across devices/GPUs. use_hard_negatives: Enable hard negative mining. num_hard_negatives: Number of hard negatives to mine. hard_negative_mining_steps: Steps between hard negative mining refreshes. ance_negative_range: Range for ANCE negative mining. ance_margin: Margin for ANCE mining. use_self_distill: Enable self-knowledge distillation. self_distill_temperature: Temperature for self-distillation. self_distill_alpha: Alpha (weight) for self-distillation loss. use_query_instruction: Add instruction to queries. query_instruction_for_retrieval: Instruction for retrieval queries. query_instruction_for_reranking: Instruction for reranking queries. augment_queries: Enable query augmentation. augmentation_methods: List of augmentation methods to use. create_hard_negatives: Create hard negatives via augmentation. dense_weight: Loss weight for dense head. sparse_weight: Loss weight for sparse head. colbert_weight: Loss weight for ColBERT head. distillation_weight: Loss weight for distillation. eval_strategy: Evaluation strategy ('epoch', 'steps'). metric_for_best_model: Metric to select best model. greater_is_better: Whether higher metric is better. use_gradient_checkpointing: Enable gradient checkpointing. max_passages_per_query: Max passages per query in memory. max_query_length: Max query length. max_passage_length: Max passage length. max_length: Max total input length (multi-granularity). 
""" # Model specific model_name_or_path: str = "BAAI/bge-m3" # Multi-functionality settings use_dense: bool = True use_sparse: bool = True use_colbert: bool = True unified_finetuning: bool = True # Architecture settings sentence_pooling_method: str = "cls" # BGE-M3 uses CLS pooling normalize_embeddings: bool = True colbert_dim: int = -1 # -1 means use model hidden size sparse_top_k: int = 100 # Top-k for sparse representations # Training specific temperature: float = 0.02 # InfoNCE temperature (from BGE-M3 paper) train_group_size: int = 8 # Number of passages per query negatives_cross_device: bool = True # Share negatives across GPUs # Hard negative mining use_hard_negatives: bool = True num_hard_negatives: int = 15 hard_negative_mining_steps: int = 1000 ance_negative_range: tuple = (10, 100) ance_margin: float = 0.1 # Self-distillation use_self_distill: bool = False # Changed default to False - make KD opt-in self_distill_temperature: float = 4.0 self_distill_alpha: float = 0.5 # Query instruction use_query_instruction: bool = False query_instruction_for_retrieval: str = "Represent this sentence for searching relevant passages:" query_instruction_for_reranking: str = "" # Data augmentation augment_queries: bool = False augmentation_methods: List[str] = field(default_factory=lambda: ["paraphrase", "back_translation"]) create_hard_negatives: bool = False # Loss weights for multi-task learning dense_weight: float = 1.0 sparse_weight: float = 0.3 colbert_weight: float = 0.5 distillation_weight: float = 1.0 # Evaluation eval_strategy: str = "epoch" metric_for_best_model: str = "eval_recall@10" greater_is_better: bool = True # Memory optimization use_gradient_checkpointing: bool = True max_passages_per_query: int = 32 # Specific to M3's multi-granularity max_query_length: int = 64 max_passage_length: int = 512 max_length: int = 8192 # BGE-M3 supports up to 8192 tokens def __post_init__(self): """Load model-specific parameters from [m3] section in config.toml""" super().__post_init__() try: from utils.config_loader import get_config except ImportError: from bge_finetune.utils.config_loader import get_config # type: ignore[import] config = get_config() self._apply_m3_config(config) def _apply_m3_config(self, config): """ Apply BGE-M3 specific configuration from [m3] section in config.toml. 
""" m3_config = config.get_section('m3') if not m3_config: logger.warning("No [m3] section found in config.toml") return # Helper to convert tuple from config to list if needed def get_val(key, default): if key not in m3_config: return default v = m3_config[key] if isinstance(default, list) and isinstance(v, tuple): return list(v) return v # Only update if key exists in config if 'use_dense' in m3_config: self.use_dense = m3_config['use_dense'] if 'use_sparse' in m3_config: self.use_sparse = m3_config['use_sparse'] if 'use_colbert' in m3_config: self.use_colbert = m3_config['use_colbert'] if 'unified_finetuning' in m3_config: self.unified_finetuning = m3_config['unified_finetuning'] if 'sentence_pooling_method' in m3_config: self.sentence_pooling_method = m3_config['sentence_pooling_method'] if 'normalize_embeddings' in m3_config: self.normalize_embeddings = m3_config['normalize_embeddings'] if 'colbert_dim' in m3_config: self.colbert_dim = m3_config['colbert_dim'] if 'sparse_top_k' in m3_config: self.sparse_top_k = m3_config['sparse_top_k'] if 'temperature' in m3_config: self.temperature = m3_config['temperature'] if 'train_group_size' in m3_config: self.train_group_size = m3_config['train_group_size'] if 'negatives_cross_device' in m3_config: self.negatives_cross_device = m3_config['negatives_cross_device'] if 'use_hard_negatives' in m3_config: self.use_hard_negatives = m3_config['use_hard_negatives'] if 'num_hard_negatives' in m3_config: self.num_hard_negatives = m3_config['num_hard_negatives'] if 'hard_negative_mining_steps' in m3_config: self.hard_negative_mining_steps = m3_config['hard_negative_mining_steps'] if 'ance_negative_range' in m3_config: self.ance_negative_range = tuple(get_val('ance_negative_range', self.ance_negative_range)) if 'augmentation_methods' in m3_config: self.augmentation_methods = list(get_val('augmentation_methods', self.augmentation_methods)) if 'ance_margin' in m3_config: self.ance_margin = m3_config['ance_margin'] if 'use_self_distill' in m3_config: self.use_self_distill = m3_config['use_self_distill'] if 'self_distill_temperature' in m3_config: self.self_distill_temperature = m3_config['self_distill_temperature'] if 'self_distill_alpha' in m3_config: self.self_distill_alpha = m3_config['self_distill_alpha'] if 'use_query_instruction' in m3_config: self.use_query_instruction = m3_config['use_query_instruction'] if 'query_instruction_for_retrieval' in m3_config: self.query_instruction_for_retrieval = m3_config['query_instruction_for_retrieval'] if 'query_instruction_for_reranking' in m3_config: self.query_instruction_for_reranking = m3_config['query_instruction_for_reranking'] if 'augment_queries' in m3_config: self.augment_queries = m3_config['augment_queries'] if 'create_hard_negatives' in m3_config: self.create_hard_negatives = m3_config['create_hard_negatives'] if 'dense_weight' in m3_config: self.dense_weight = m3_config['dense_weight'] if 'sparse_weight' in m3_config: self.sparse_weight = m3_config['sparse_weight'] if 'colbert_weight' in m3_config: self.colbert_weight = m3_config['colbert_weight'] if 'distillation_weight' in m3_config: self.distillation_weight = m3_config['distillation_weight'] if 'metric_for_best_model' in m3_config: self.metric_for_best_model = m3_config['metric_for_best_model'] if 'greater_is_better' in m3_config: self.greater_is_better = m3_config['greater_is_better'] if 'use_gradient_checkpointing' in m3_config: self.use_gradient_checkpointing = m3_config['use_gradient_checkpointing'] if 'max_passages_per_query' in m3_config: 
    def validate(self):
        """Validate configuration settings."""
        if not any([self.use_dense, self.use_sparse, self.use_colbert]):
            raise ValueError("At least one of use_dense, use_sparse, or use_colbert must be True")
        if self.unified_finetuning and not all([self.use_dense, self.use_sparse, self.use_colbert]):
            logger.warning("unified_finetuning is True but not all functionalities are enabled")
        if self.temperature <= 0:
            raise ValueError("Temperature must be positive")
        if self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2")
        if self.max_length > 8192:
            logger.warning("max_length > 8192 may not be supported by BGE-M3")
        return True
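
# Minimal usage sketch (illustrative, not part of the module's API). It assumes
# that every BaseConfig field has a default and that utils.config_loader.get_config()
# can locate a config.toml; both are assumptions about the surrounding project.
if __name__ == "__main__":
    # Instantiation triggers __post_init__, which pulls any [m3] overrides from
    # config.toml; validate() then enforces the cross-field constraints above.
    cfg = BGEM3Config()
    cfg.validate()
    print(cfg.model_name_or_path, cfg.temperature, cfg.train_group_size)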