"""
|
|
BGE-M3 specific configuration
|
|
"""
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional, List, Dict
|
|
from .base_config import BaseConfig
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class BGEM3Config(BaseConfig):
    """Configuration for BGE-M3 embedding model fine-tuning (model and trainer settings).

    Attributes:
        use_in_batch_negatives: Use in-batch negatives for the contrastive loss.
        teacher_temperature: Temperature for teacher distillation.
        kd_loss_weight: Weight for the knowledge-distillation (KD) loss.
        kd_loss_type: Type of KD loss (e.g., 'mse', 'kl').
        hard_negative_margin: Margin for hard negative mining.
        model_name_or_path: Path or identifier of the pretrained BGE-M3 model.
        use_dense: Enable the dense embedding head.
        use_sparse: Enable the sparse (SPLADE-style) embedding head.
        use_colbert: Enable the ColBERT multi-vector head.
        unified_finetuning: Fine-tune all functionalities jointly.
        sentence_pooling_method: Pooling method for sentence embeddings (e.g., 'cls').
        normalize_embeddings: Whether to normalize output embeddings.
        colbert_dim: Output dimension of the ColBERT head (-1 uses the model hidden size).
        sparse_top_k: Top-k for sparse representations.
        temperature: InfoNCE loss temperature.
        train_group_size: Number of passages per query during training.
        negatives_cross_device: Share negatives across devices/GPUs.
        use_hard_negatives: Enable hard negative mining.
        num_hard_negatives: Number of hard negatives to mine.
        hard_negative_mining_steps: Steps between hard-negative mining refreshes.
        ance_negative_range: Rank range for ANCE negative mining.
        ance_margin: Margin for ANCE mining.
        use_self_distill: Enable self-knowledge distillation.
        self_distill_temperature: Temperature for self-distillation.
        self_distill_alpha: Weight of the self-distillation loss.
        use_query_instruction: Prepend an instruction to queries.
        query_instruction_for_retrieval: Instruction for retrieval queries.
        query_instruction_for_reranking: Instruction for reranking queries.
        augment_queries: Enable query augmentation.
        augmentation_methods: List of augmentation methods to use.
        create_hard_negatives: Create hard negatives via augmentation.
        dense_weight: Loss weight for the dense head.
        sparse_weight: Loss weight for the sparse head.
        colbert_weight: Loss weight for the ColBERT head.
        distillation_weight: Loss weight for distillation.
        eval_strategy: Evaluation strategy ('epoch' or 'steps').
        metric_for_best_model: Metric used to select the best model.
        greater_is_better: Whether a higher metric value is better.
        use_gradient_checkpointing: Enable gradient checkpointing.
        max_passages_per_query: Maximum passages per query held in memory.
        max_query_length: Maximum query length in tokens.
        max_passage_length: Maximum passage length in tokens.
        max_length: Maximum total input length (multi-granularity).
    """

    # Knowledge distillation and contrastive-loss settings
    use_in_batch_negatives: bool = False  # Use in-batch negatives for contrastive loss
    teacher_temperature: float = 1.0  # Temperature for teacher distillation
    kd_loss_weight: float = 1.0  # Weight for KD loss
    kd_loss_type: str = "mse"  # Type of KD loss (e.g., mse, kl)
    hard_negative_margin: float = 0.0  # Margin for hard negative mining

    # Model specific
    model_name_or_path: str = "BAAI/bge-m3"

    # Multi-functionality settings
    use_dense: bool = True
    use_sparse: bool = True
    use_colbert: bool = True
    unified_finetuning: bool = True

    # Architecture settings
    sentence_pooling_method: str = "cls"  # BGE-M3 uses CLS pooling
    normalize_embeddings: bool = True
    colbert_dim: int = -1  # -1 means use the model hidden size
    sparse_top_k: int = 100  # Top-k for sparse representations

    # Training specific
    temperature: float = 0.02  # InfoNCE temperature (from the BGE-M3 paper)
    train_group_size: int = 8  # Number of passages per query
    negatives_cross_device: bool = True  # Share negatives across GPUs

    # Hard negative mining
    use_hard_negatives: bool = True
    num_hard_negatives: int = 15
    hard_negative_mining_steps: int = 1000
    ance_negative_range: tuple = (10, 100)
    ance_margin: float = 0.1

    # Self-distillation
    use_self_distill: bool = False  # Off by default; knowledge distillation is opt-in
    self_distill_temperature: float = 4.0
    self_distill_alpha: float = 0.5

    # Query instruction
    use_query_instruction: bool = False
    query_instruction_for_retrieval: str = "Represent this sentence for searching relevant passages:"
    query_instruction_for_reranking: str = ""

    # Data augmentation
    augment_queries: bool = False
    augmentation_methods: List[str] = field(default_factory=lambda: ["paraphrase", "back_translation"])
    create_hard_negatives: bool = False

    # Loss weights for multi-task learning
    dense_weight: float = 1.0
    sparse_weight: float = 0.3
    colbert_weight: float = 0.5
    distillation_weight: float = 1.0

    # Evaluation
    eval_strategy: str = "epoch"
    metric_for_best_model: str = "eval_recall@10"
    greater_is_better: bool = True

    # Memory optimization
    use_gradient_checkpointing: bool = True
    max_passages_per_query: int = 32

    # Specific to M3's multi-granularity handling
    max_query_length: int = 64
    max_passage_length: int = 512
    max_length: int = 8192  # BGE-M3 supports inputs up to 8192 tokens

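    # Field defaults above can be overridden per instance, e.g.
    # BGEM3Config(use_colbert=False, temperature=0.05). Note that __post_init__
    # runs after construction, so any matching keys found in the [m3] section of
    # config.toml take precedence over constructor arguments.
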
    def __post_init__(self):
        """Load model-specific parameters from the [m3] section of config.toml."""
        super().__post_init__()
        try:
            from utils.config_loader import get_config
        except ImportError:
            from bge_finetune.utils.config_loader import get_config  # type: ignore[import]
        config = get_config()
        self._apply_m3_config(config)

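    # For reference, _apply_m3_config below reads keys like the following from
    # config.toml. This is an illustrative sketch only -- the exact file location
    # and section contents depend on how utils.config_loader.get_config resolves
    # the configuration:
    #
    #   [m3]
    #   use_dense = true
    #   use_sparse = true
    #   use_colbert = false
    #   temperature = 0.02
    #   train_group_size = 8
    #   ance_negative_range = [10, 100]
    #   augmentation_methods = ["paraphrase", "back_translation"]
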
    def _apply_m3_config(self, config):
        """Apply BGE-M3 specific configuration from the [m3] section of config.toml.

        Only keys present in the [m3] section override the dataclass defaults.
        """
        m3_config = config.get_section('m3')
        if not m3_config:
            logger.warning("No [m3] section found in config.toml")
            return

        # Keys copied verbatim when present in the [m3] section
        simple_keys = (
            'use_dense', 'use_sparse', 'use_colbert', 'unified_finetuning',
            'sentence_pooling_method', 'normalize_embeddings', 'colbert_dim',
            'sparse_top_k', 'temperature', 'train_group_size',
            'negatives_cross_device', 'use_hard_negatives', 'num_hard_negatives',
            'hard_negative_mining_steps', 'ance_margin', 'use_self_distill',
            'self_distill_temperature', 'self_distill_alpha',
            'use_query_instruction', 'query_instruction_for_retrieval',
            'query_instruction_for_reranking', 'augment_queries',
            'create_hard_negatives', 'dense_weight', 'sparse_weight',
            'colbert_weight', 'distillation_weight', 'metric_for_best_model',
            'greater_is_better', 'use_gradient_checkpointing',
            'max_passages_per_query', 'max_query_length', 'max_passage_length',
            'max_length',
        )
        for key in simple_keys:
            if key in m3_config:
                setattr(self, key, m3_config[key])

        # Keys that need a container conversion (TOML arrays may arrive as
        # lists or tuples depending on the loader)
        if 'ance_negative_range' in m3_config:
            self.ance_negative_range = tuple(m3_config['ance_negative_range'])
        if 'augmentation_methods' in m3_config:
            self.augmentation_methods = list(m3_config['augmentation_methods'])

        # Log parameter overrides
        logger.info("Loaded BGE-M3 config from [m3] section in config.toml")

        # Check required parameters
        if self.temperature is None or self.temperature <= 0:
            raise ValueError("Temperature must be positive and not None")
        if self.train_group_size is None or self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2 and not None")

    def validate(self):
        """Validate configuration settings."""
        if not any([self.use_dense, self.use_sparse, self.use_colbert]):
            raise ValueError("At least one of use_dense, use_sparse, or use_colbert must be True")

        if self.unified_finetuning and not all([self.use_dense, self.use_sparse, self.use_colbert]):
            logger.warning("unified_finetuning is True but not all functionalities are enabled")

        if self.temperature <= 0:
            raise ValueError("Temperature must be positive")

        if self.train_group_size < 2:
            raise ValueError("train_group_size must be at least 2")

        if self.max_length > 8192:
            logger.warning("max_length > 8192 may not be supported by BGE-M3")

        return True
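

# Minimal usage sketch (illustrative; assumes a config.toml with an [m3] section
# is discoverable by utils.config_loader.get_config, and that this module is
# imported as part of its package so the relative import of BaseConfig works):
#
#   cfg = BGEM3Config()                      # [m3] overrides applied in __post_init__
#   cfg.validate()                           # raises ValueError on inconsistent settings
#   print(cfg.temperature, cfg.max_length)   # e.g. 0.02 8192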