bge_finetune/config/m3_config.py
"""
BGE-M3 specific configuration
"""
from dataclasses import dataclass, field
from typing import Optional, List, Dict
from .base_config import BaseConfig
import logging
logger = logging.getLogger(__name__)
@dataclass
class BGEM3Config(BaseConfig):
"""
Configuration for BGE-M3 model and trainer.
"""
use_in_batch_negatives: bool = False # Use in-batch negatives for contrastive loss
teacher_temperature: float = 1.0 # Temperature for teacher distillation
kd_loss_weight: float = 1.0 # Weight for KD loss
kd_loss_type: str = "mse" # Type of KD loss (e.g., mse, kl)
hard_negative_margin: float = 0.0 # Margin for hard negative mining
"""Configuration for BGE-M3 embedding model fine-tuning
Attributes:
model_name_or_path: Path or identifier for the pretrained BGE-M3 model.
use_dense: Enable dense embedding head.
use_sparse: Enable sparse (SPLADE-style) embedding head.
use_colbert: Enable ColBERT multi-vector head.
unified_finetuning: Unified fine-tuning for all functionalities.
sentence_pooling_method: Pooling method for sentence embeddings (e.g., 'cls').
normalize_embeddings: Whether to normalize output embeddings.
colbert_dim: Output dimension for ColBERT head (-1 uses model hidden size).
sparse_top_k: Top-k for sparse representations.
temperature: InfoNCE loss temperature.
train_group_size: Number of passages per query in training.
negatives_cross_device: Share negatives across devices/GPUs.
use_hard_negatives: Enable hard negative mining.
num_hard_negatives: Number of hard negatives to mine.
hard_negative_mining_steps: Steps between hard negative mining refreshes.
ance_negative_range: Range for ANCE negative mining.
ance_margin: Margin for ANCE mining.
use_self_distill: Enable self-knowledge distillation.
self_distill_temperature: Temperature for self-distillation.
self_distill_alpha: Alpha (weight) for self-distillation loss.
use_query_instruction: Add instruction to queries.
query_instruction_for_retrieval: Instruction for retrieval queries.
query_instruction_for_reranking: Instruction for reranking queries.
augment_queries: Enable query augmentation.
augmentation_methods: List of augmentation methods to use.
create_hard_negatives: Create hard negatives via augmentation.
dense_weight: Loss weight for dense head.
sparse_weight: Loss weight for sparse head.
colbert_weight: Loss weight for ColBERT head.
distillation_weight: Loss weight for distillation.
eval_strategy: Evaluation strategy ('epoch', 'steps').
metric_for_best_model: Metric to select best model.
greater_is_better: Whether higher metric is better.
use_gradient_checkpointing: Enable gradient checkpointing.
max_passages_per_query: Max passages per query in memory.
max_query_length: Max query length.
max_passage_length: Max passage length.
max_length: Max total input length (multi-granularity).
"""
# Model specific
model_name_or_path: str = "BAAI/bge-m3"
# Multi-functionality settings
use_dense: bool = True
use_sparse: bool = True
use_colbert: bool = True
unified_finetuning: bool = True
# Architecture settings
sentence_pooling_method: str = "cls" # BGE-M3 uses CLS pooling
normalize_embeddings: bool = True
colbert_dim: int = -1 # -1 means use model hidden size
sparse_top_k: int = 100 # Top-k for sparse representations
# Training specific
temperature: float = 0.02 # InfoNCE temperature (from BGE-M3 paper)
train_group_size: int = 8 # Number of passages per query
negatives_cross_device: bool = True # Share negatives across GPUs
# Hard negative mining
use_hard_negatives: bool = True
num_hard_negatives: int = 15
hard_negative_mining_steps: int = 1000
ance_negative_range: tuple = (10, 100)
ance_margin: float = 0.1
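    # Note: ance_negative_range is read as a retrieval-rank window, i.e. hard
    # negatives are drawn from passages ranked within (low, high) for the query;
    # the sampling itself is presumed to live in the hard-negative miner.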
# Self-distillation
    use_self_distill: bool = False  # Opt-in: self-distillation (KD) is disabled by default
self_distill_temperature: float = 4.0
self_distill_alpha: float = 0.5
# Query instruction
use_query_instruction: bool = False
query_instruction_for_retrieval: str = "Represent this sentence for searching relevant passages:"
query_instruction_for_reranking: str = ""
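    # Note: when use_query_instruction is True, the data pipeline is expected to
    # prepend the instruction to the raw query text (roughly
    # f"{query_instruction_for_retrieval} {query}"); the exact formatting is
    # handled outside this config.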
# Data augmentation
augment_queries: bool = False
augmentation_methods: List[str] = field(default_factory=lambda: ["paraphrase", "back_translation"])
create_hard_negatives: bool = False
# Loss weights for multi-task learning
dense_weight: float = 1.0
sparse_weight: float = 0.3
colbert_weight: float = 0.5
distillation_weight: float = 1.0
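    # These weights are presumably combined by the trainer along the lines of
    #   loss = dense_weight * L_dense + sparse_weight * L_sparse
    #        + colbert_weight * L_colbert + distillation_weight * L_distill
    # (the exact formula is defined by the trainer, not by this config).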
# Evaluation
eval_strategy: str = "epoch"
metric_for_best_model: str = "eval_recall@10"
greater_is_better: bool = True
# Memory optimization
use_gradient_checkpointing: bool = True
max_passages_per_query: int = 32
# Specific to M3's multi-granularity
max_query_length: int = 64
max_passage_length: int = 512
max_length: int = 8192 # BGE-M3 supports up to 8192 tokens
def __post_init__(self):
"""Load model-specific parameters from [m3] section in config.toml"""
super().__post_init__()
try:
from utils.config_loader import get_config
except ImportError:
from bge_finetune.utils.config_loader import get_config # type: ignore[import]
config = get_config()
self._apply_m3_config(config)
def _apply_m3_config(self, config):
"""
Apply BGE-M3 specific configuration from [m3] section in config.toml.
"""
m3_config = config.get_section('m3')
if not m3_config:
logger.warning("No [m3] section found in config.toml")
return
        # Helper: read a key from m3_config, converting tuples to lists when the default is a list
def get_val(key, default):
if key not in m3_config:
return default
v = m3_config[key]
if isinstance(default, list) and isinstance(v, tuple):
return list(v)
return v
# Only update if key exists in config
if 'use_dense' in m3_config:
self.use_dense = m3_config['use_dense']
if 'use_sparse' in m3_config:
self.use_sparse = m3_config['use_sparse']
if 'use_colbert' in m3_config:
self.use_colbert = m3_config['use_colbert']
if 'unified_finetuning' in m3_config:
self.unified_finetuning = m3_config['unified_finetuning']
if 'sentence_pooling_method' in m3_config:
self.sentence_pooling_method = m3_config['sentence_pooling_method']
if 'normalize_embeddings' in m3_config:
self.normalize_embeddings = m3_config['normalize_embeddings']
if 'colbert_dim' in m3_config:
self.colbert_dim = m3_config['colbert_dim']
if 'sparse_top_k' in m3_config:
self.sparse_top_k = m3_config['sparse_top_k']
if 'temperature' in m3_config:
self.temperature = m3_config['temperature']
if 'train_group_size' in m3_config:
self.train_group_size = m3_config['train_group_size']
if 'negatives_cross_device' in m3_config:
self.negatives_cross_device = m3_config['negatives_cross_device']
if 'use_hard_negatives' in m3_config:
self.use_hard_negatives = m3_config['use_hard_negatives']
if 'num_hard_negatives' in m3_config:
self.num_hard_negatives = m3_config['num_hard_negatives']
if 'hard_negative_mining_steps' in m3_config:
self.hard_negative_mining_steps = m3_config['hard_negative_mining_steps']
if 'ance_negative_range' in m3_config:
self.ance_negative_range = tuple(get_val('ance_negative_range', self.ance_negative_range))
if 'augmentation_methods' in m3_config:
self.augmentation_methods = list(get_val('augmentation_methods', self.augmentation_methods))
if 'ance_margin' in m3_config:
self.ance_margin = m3_config['ance_margin']
if 'use_self_distill' in m3_config:
self.use_self_distill = m3_config['use_self_distill']
if 'self_distill_temperature' in m3_config:
self.self_distill_temperature = m3_config['self_distill_temperature']
if 'self_distill_alpha' in m3_config:
self.self_distill_alpha = m3_config['self_distill_alpha']
if 'use_query_instruction' in m3_config:
self.use_query_instruction = m3_config['use_query_instruction']
if 'query_instruction_for_retrieval' in m3_config:
self.query_instruction_for_retrieval = m3_config['query_instruction_for_retrieval']
if 'query_instruction_for_reranking' in m3_config:
self.query_instruction_for_reranking = m3_config['query_instruction_for_reranking']
if 'augment_queries' in m3_config:
self.augment_queries = m3_config['augment_queries']
if 'create_hard_negatives' in m3_config:
self.create_hard_negatives = m3_config['create_hard_negatives']
if 'dense_weight' in m3_config:
self.dense_weight = m3_config['dense_weight']
if 'sparse_weight' in m3_config:
self.sparse_weight = m3_config['sparse_weight']
if 'colbert_weight' in m3_config:
self.colbert_weight = m3_config['colbert_weight']
if 'distillation_weight' in m3_config:
self.distillation_weight = m3_config['distillation_weight']
if 'metric_for_best_model' in m3_config:
self.metric_for_best_model = m3_config['metric_for_best_model']
if 'greater_is_better' in m3_config:
self.greater_is_better = m3_config['greater_is_better']
if 'use_gradient_checkpointing' in m3_config:
self.use_gradient_checkpointing = m3_config['use_gradient_checkpointing']
if 'max_passages_per_query' in m3_config:
self.max_passages_per_query = m3_config['max_passages_per_query']
if 'max_query_length' in m3_config:
self.max_query_length = m3_config['max_query_length']
if 'max_passage_length' in m3_config:
self.max_passage_length = m3_config['max_passage_length']
if 'max_length' in m3_config:
self.max_length = m3_config['max_length']
        # Record that overrides from the [m3] section were applied
logger.info("Loaded BGE-M3 config from [m3] section in config.toml")
# Check for required parameters
if self.temperature is None or self.temperature <= 0:
raise ValueError("Temperature must be positive and not None")
if self.train_group_size is None or self.train_group_size < 2:
raise ValueError("train_group_size must be at least 2 and not None")
def validate(self):
"""Validate configuration settings"""
if not any([self.use_dense, self.use_sparse, self.use_colbert]):
raise ValueError("At least one of use_dense, use_sparse, or use_colbert must be True")
if self.unified_finetuning and not all([self.use_dense, self.use_sparse, self.use_colbert]):
logger.warning("unified_finetuning is True but not all functionalities are enabled")
if self.temperature <= 0:
raise ValueError("Temperature must be positive")
if self.train_group_size < 2:
raise ValueError("train_group_size must be at least 2")
if self.max_length > 8192:
logger.warning("max_length > 8192 may not be supported by BGE-M3")
return True
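

if __name__ == "__main__":
    # Minimal usage sketch, for illustration only: it assumes a config.toml with
    # an [m3] section is discoverable by get_config(), and that BaseConfig
    # provides defaults for all of its own fields.
    cfg = BGEM3Config()
    cfg.validate()
    print(
        f"BGE-M3 config: temperature={cfg.temperature}, "
        f"train_group_size={cfg.train_group_size}, "
        f"heads=(dense={cfg.use_dense}, sparse={cfg.use_sparse}, colbert={cfg.use_colbert})"
    )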