import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer
import os
import json
import logging
from tqdm import tqdm
from typing import Dict, Optional, List, Tuple
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.losses import InfoNCELoss, M3KnowledgeDistillationLoss, MultiTaskLoss
from data.dataset import BGEM3Dataset

# Optional hard negative mining
try:
    from data.hard_negative_mining import ANCEHardNegativeMiner
    ANCE_AVAILABLE = True
except ImportError:
    ANCE_AVAILABLE = False
    ANCEHardNegativeMiner = None

from utils.checkpoint import CheckpointManager
from utils.logging import get_logger
from evaluation.metrics import compute_retrieval_metrics
from config.m3_config import BGEM3Config

# Import scheduler with fallback
try:
    from transformers import get_linear_schedule_with_warmup
except ImportError:
    get_linear_schedule_with_warmup = None

logger = get_logger(__name__)

class BGEM3Trainer(BaseTrainer):
    """
    Trainer for the BGE-M3 embedding model, combining dense (InfoNCE), sparse, and
    ColBERT multi-vector objectives with optional self-distillation and ANCE-style
    hard negative mining.

    Notes:
    - All parameters (config, model, tokenizer, etc.) should be passed in from
      config-driven training scripts.
    - This class does not load a config file itself; pass config values from your
      main script for consistency.
    """

    def __init__(
        self,
        config: BGEM3Config,
        model: Optional[BGEM3Model] = None,
        tokenizer: Optional[AutoTokenizer] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        # Store config first
        self.config = config

        # Initialize tokenizer first
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.model_name_or_path,
                cache_dir=config.cache_dir,
                trust_remote_code=True
            )
        else:
            self.tokenizer = tokenizer

        # Initialize model
        if model is None:
            self.model = BGEM3Model(
                model_name_or_path=config.model_name_or_path,
                use_dense=config.use_dense,
                use_sparse=config.use_sparse,
                use_colbert=config.use_colbert,
                pooling_method=config.sentence_pooling_method,
                normalize_embeddings=config.normalize_embeddings,
                colbert_dim=config.colbert_dim,
                sparse_top_k=config.sparse_top_k,
                cache_dir=config.cache_dir
            )
        else:
            self.model = model

        # Initialize optimizer if not provided
        if optimizer is None:
            optimizer = self._create_optimizer()

        # If no scheduler was provided, it is created later in train(), once the
        # total number of training steps is known (see create_scheduler).

        # Initialize parent class
        super().__init__(config, self.model, optimizer, scheduler, device)

        # Initialize losses
        self.infonce_loss = InfoNCELoss(
            temperature=config.temperature,
            reduction='mean'
        )

        self.kd_loss = M3KnowledgeDistillationLoss(
            temperature=config.self_distill_temperature,
            alpha=config.self_distill_alpha,
            distill_types=['dense', 'sparse', 'colbert']
        )

        self.multi_task_loss = MultiTaskLoss(
            loss_weights={
                'dense': config.dense_weight,
                'sparse': config.sparse_weight,
                'colbert': config.colbert_weight,
                'distillation': config.distillation_weight
            }
        )

        # Initialize hard negative miner
        if config.use_hard_negatives and ANCE_AVAILABLE:
            self.hard_negative_miner = ANCEHardNegativeMiner(
                embedding_dim=self.model.config.hidden_size,  # type: ignore[attr-defined]
                num_hard_negatives=config.num_hard_negatives,
                negative_range=config.ance_negative_range,
                margin=config.ance_margin,
                index_refresh_interval=config.hard_negative_mining_steps,
                use_gpu=(self.device.type == "cuda")
            )
            self.hard_negative_miner.start_async_updates()
        else:
            if config.use_hard_negatives and not ANCE_AVAILABLE:
                logger.warning("Hard negative mining requested but ANCEHardNegativeMiner not available")
            self.hard_negative_miner = None

        logger.info(f"Initialized BGE-M3 trainer with config: {config.to_dict()}")

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create optimizer with parameter groups"""
        # Separate parameters for different learning rates if needed
        param_groups = [
            {
                'params': [p for n, p in self.model.named_parameters()
                           if 'colbert_linear' not in n and 'sparse_linear' not in n],
                'lr': self.config.learning_rate
            }
        ]

        # Different LR for task-specific layers
        if self.config.use_colbert and hasattr(self.model, 'colbert_linear'):
            param_groups.append({
                'params': self.model.colbert_linear.parameters(),
                'lr': self.config.learning_rate * 2  # Higher LR for new layers
            })

        if self.config.use_sparse and hasattr(self.model, 'sparse_linear'):
            param_groups.append({
                'params': self.model.sparse_linear.parameters(),
                'lr': self.config.learning_rate * 2
            })

        optimizer = AdamW(
            param_groups,
            eps=self.config.adam_epsilon,
            weight_decay=self.config.weight_decay,
            betas=(self.config.adam_beta1, self.config.adam_beta2)
        )

        return optimizer

    def create_scheduler(self, num_training_steps: int):
        """Create learning rate scheduler"""
        if get_linear_schedule_with_warmup is None:
            logger.warning("get_linear_schedule_with_warmup not available, using a constant learning rate")
            return None

        num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )
        return self.scheduler

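    # Worked example (illustrative numbers): with num_training_steps=10_000, warmup_steps
    # unset and warmup_ratio=0.1, num_warmup_steps = int(10_000 * 0.1) = 1_000, so the
    # learning rate ramps up linearly for 1,000 steps and then decays linearly toward zero
    # over the remaining 9,000 steps.
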
    def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute M3 loss with multiple objectives"""

        # Update hard negative mining if hook is available
        if hasattr(self, '_mining_hook'):
            self._mining_hook()

        # Detect batch format and handle accordingly
        if 'pos_input_ids' in batch and 'neg_input_ids' in batch:
            # This is triplet format from FastM3Dataset
            return self._compute_triplet_loss(batch)
        elif 'passage_input_ids' in batch:
            # This is standard embedding format
            return self._compute_embedding_loss(batch)
        else:
            raise ValueError(f"Unknown batch format. Keys: {list(batch.keys())}")

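    # Expected batch layouts, inferred from the keys used above (a working assumption about
    # the collators rather than a formal schema):
    #   triplet format (FastM3Dataset):
    #       query_input_ids, query_attention_mask,
    #       pos_input_ids, pos_attention_mask,
    #       neg_input_ids, neg_attention_mask
    #   standard embedding format:
    #       query_input_ids, query_attention_mask,
    #       passage_input_ids, passage_attention_mask   # [B, L] or [B, num_passages, L]
    #       labels (optional)                           # [B, num_passages] relevance labels
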
    def _compute_triplet_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for triplet format data (FastM3Dataset)"""
        # Get query embeddings
        query_outputs = self.model.forward(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Get positive embeddings
        pos_outputs = self.model.forward(
            input_ids=batch['pos_input_ids'],
            attention_mask=batch['pos_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Get negative embeddings
        neg_outputs = self.model.forward(
            input_ids=batch['neg_input_ids'],
            attention_mask=batch['neg_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        losses = {}

        # Dense loss (InfoNCE with explicit pos/neg)
        if 'dense' in query_outputs:
            query_dense = query_outputs['dense'].float()  # type: ignore[index]
            pos_dense = pos_outputs['dense'].float()  # type: ignore[index]
            neg_dense = neg_outputs['dense'].float()  # type: ignore[index]

            # For triplet format, reshape negative embeddings to expected 3D format
            # InfoNCE expects negative_embeddings: [batch_size, num_negatives, embedding_dim]
            # But we have single negatives: [batch_size, embedding_dim]
            # Reshape to [batch_size, 1, embedding_dim]
            neg_dense_3d = neg_dense.unsqueeze(1)  # [batch_size, 1, embedding_dim]

            dense_loss = self.infonce_loss(
                query_dense,
                pos_dense,
                neg_dense_3d  # Now properly shaped 3D tensor
            )
            losses['dense'] = dense_loss

        # Sparse loss
        if self.config.use_sparse and 'sparse' in query_outputs:
            from models.losses import SparseLoss
            sparse_loss_fn = SparseLoss(
                flops_loss_weight=1e-3,
                sparse_reg_weight=1e-4
            )

            # Combine pos and neg for sparse loss
            combined_sparse = torch.cat([pos_outputs['sparse'], neg_outputs['sparse']], dim=0)  # type: ignore[index]
            query_sparse_repeated = query_outputs['sparse'].repeat(2, 1)  # type: ignore[index]
            labels = torch.cat([torch.ones(len(pos_outputs['sparse'])),  # type: ignore[index]
                                torch.zeros(len(neg_outputs['sparse']))], dim=0).to(query_sparse_repeated.device)  # type: ignore[index]

            sparse_loss, _ = sparse_loss_fn(
                query_sparse_repeated,
                combined_sparse,
                labels
            )
            losses['sparse'] = sparse_loss

        # ColBERT loss
        if self.config.use_colbert and 'colbert' in query_outputs:
            from models.losses import ColBERTLoss
            colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)

            # Combine pos and neg for ColBERT loss
            combined_colbert = torch.cat([pos_outputs['colbert'], neg_outputs['colbert']], dim=0)  # type: ignore[index]
            query_colbert_repeated = query_outputs['colbert'].repeat(2, 1, 1)  # type: ignore[index]
            labels = torch.cat([torch.ones(len(pos_outputs['colbert'])),  # type: ignore[index]
                                torch.zeros(len(neg_outputs['colbert']))], dim=0).to(query_colbert_repeated.device)  # type: ignore[index]

            colbert_loss = colbert_loss_fn(
                query_colbert_repeated,
                combined_colbert,
                labels
            )
            losses['colbert'] = colbert_loss

        # Combine losses
        total_loss, _ = self.multi_task_loss(losses)
        return total_loss

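    # Reference: assuming InfoNCELoss follows the standard formulation, the dense objective
    # used above is, for query q, positive p+, negatives {p-_j} and temperature T:
    #   L = -log( exp(sim(q, p+)/T) / ( exp(sim(q, p+)/T) + sum_j exp(sim(q, p-_j)/T) ) )
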
    def _compute_embedding_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute loss for standard embedding format data"""
        # Get query embeddings
        query_outputs = self.model.forward(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=self.config.use_dense,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Get passage embeddings - handle both single and multiple passages
        if batch['passage_input_ids'].dim() == 3:
            # Multiple passages per query [batch, num_passages, seq_len]
            batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
            passage_input_ids = batch['passage_input_ids'].view(batch_size * num_passages, seq_len)
            passage_attention_mask = batch['passage_attention_mask'].view(batch_size * num_passages, seq_len)
            multiple_passages = True
        else:
            # Single passage per query [batch, seq_len]
            passage_input_ids = batch['passage_input_ids']
            passage_attention_mask = batch['passage_attention_mask']
            batch_size = passage_input_ids.shape[0]
            num_passages = 1
            multiple_passages = False

        passage_outputs = self.model.forward(
            input_ids=passage_input_ids,
            attention_mask=passage_attention_mask,
            return_dense=True,
            return_sparse=self.config.use_sparse,
            return_colbert=self.config.use_colbert,
            return_dict=True
        )

        # Compute losses for different components
        losses = {}

        # Dense loss (InfoNCE)
        if 'dense' in query_outputs and 'dense' in passage_outputs:
            # Get embeddings and ensure they're on the same device
            query_dense = query_outputs['dense'].float()  # type: ignore[index]  # [batch_size, embedding_dim]
            passage_dense = passage_outputs['dense'].float()  # type: ignore[index]  # [batch_size * num_passages, embedding_dim]

            # Handle multiple passages per query differently for contrastive learning
            if multiple_passages:
                # Reshape passage embeddings back to [batch_size, num_passages, embedding_dim]
                passage_dense = passage_dense.view(batch_size, num_passages, -1)

                # For contrastive learning, we need to identify positive passages
                # Use the labels to find the first positive passage for each query
                labels = batch.get('labels')  # [batch_size, num_passages]

                if labels is not None:
                    # Find the first positive passage for each query
                    positive_indices = []
                    negative_passages = []

                    for i in range(batch_size):
                        query_labels = labels[i]  # [num_passages]
                        positive_mask = query_labels > 0

                        if positive_mask.any():
                            # Get first positive passage
                            pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
                            positive_indices.append(pos_idx)

                            # Collect negative passages
                            neg_mask = ~positive_mask
                            if neg_mask.any():
                                neg_embeddings = passage_dense[i][neg_mask]  # [num_negatives, embedding_dim]
                                negative_passages.append(neg_embeddings)
                            else:
                                negative_passages.append(None)
                        else:
                            # No positive passage, use first passage as positive (fallback)
                            positive_indices.append(0)
                            negative_passages.append(passage_dense[i][1:])

                    # Extract positive embeddings
                    positive_embeddings = torch.stack([
                        passage_dense[i, positive_indices[i]] for i in range(batch_size)
                    ])  # [batch_size, embedding_dim]

                    # Stack negative embeddings (if available)
                    max_neg = max(neg.shape[0] if neg is not None else 0 for neg in negative_passages)
                    if max_neg > 0:
                        # Pad negative embeddings to same size
                        padded_negatives = []
                        for neg in negative_passages:
                            if neg is not None and neg.shape[0] > 0:
                                if neg.shape[0] < max_neg:
                                    # Pad with zeros
                                    padding = torch.zeros(max_neg - neg.shape[0], neg.shape[1], device=neg.device)
                                    neg = torch.cat([neg, padding], dim=0)
                                padded_negatives.append(neg)
                            else:
                                # No negatives, create zero tensor
                                padded_negatives.append(torch.zeros(max_neg, passage_dense.shape[-1], device=passage_dense.device))

                        negative_embeddings = torch.stack(padded_negatives)  # [batch_size, max_neg, embedding_dim]
                    else:
                        negative_embeddings = None

                    # Compute InfoNCE loss with explicit positives and negatives
                    dense_loss = self.infonce_loss(
                        query_dense,
                        positive_embeddings,
                        negative_embeddings
                    )
                else:
                    # No labels provided, use first passage as positive, rest as negatives
                    positive_embeddings = passage_dense[:, 0, :]  # [batch_size, embedding_dim]
                    if num_passages > 1:
                        negative_embeddings = passage_dense[:, 1:, :]  # [batch_size, num_passages-1, embedding_dim]
                    else:
                        negative_embeddings = None

                    dense_loss = self.infonce_loss(
                        query_dense,
                        positive_embeddings,
                        negative_embeddings
                    )
            else:
                # Single passage per query - use standard in-batch negatives
                # InfoNCE will create contrastive pairs using in-batch negatives
                dense_loss = self.infonce_loss(
                    query_dense,
                    passage_dense,
                    None  # No explicit negatives, will use in-batch
                )

            losses['dense'] = dense_loss

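            # Note on the padding above: queries with fewer real negatives are padded with
            # all-zero vectors so the negatives stack into a single tensor. Whether those
            # zero rows are masked out or scored as ordinary (weak) negatives depends on how
            # InfoNCELoss handles them; this is an assumption worth verifying when negative
            # counts vary a lot per query.
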
        # Sparse loss (if using sparse)
        if self.config.use_sparse and 'sparse' in query_outputs and 'sparse' in passage_outputs:
            from models.losses import SparseLoss
            sparse_loss_fn = SparseLoss(
                flops_loss_weight=1e-3,
                sparse_reg_weight=1e-4
            )

            query_sparse = query_outputs['sparse']  # type: ignore[index]  # [batch_size, query_seq_len]
            passage_sparse = passage_outputs['sparse']  # type: ignore[index]  # [batch_size * num_passages, doc_seq_len]

            # For sparse loss, we need to handle the contrastive learning differently
            # Since sparse embeddings have different sequence lengths, we'll use the same approach as dense
            if multiple_passages:
                # Use only the first positive passage for each query
                labels = batch.get('labels')
                if labels is not None:
                    positive_indices = []
                    for i in range(batch_size):
                        query_labels = labels[i]
                        positive_mask = query_labels > 0
                        if positive_mask.any():
                            pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
                        else:
                            pos_idx = 0
                        positive_indices.append(pos_idx + i * num_passages)

                    # Extract positive sparse embeddings
                    positive_sparse = passage_sparse[positive_indices]  # [batch_size, doc_seq_len]
                else:
                    # Use first passage as positive
                    positive_sparse = passage_sparse[::num_passages]  # [batch_size, doc_seq_len]

                # Create contrastive labels for sparse loss
                contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)

                sparse_loss, _ = sparse_loss_fn(
                    query_sparse,
                    positive_sparse,
                    contrastive_labels
                )
            else:
                # Single passage per query
                contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
                sparse_loss, _ = sparse_loss_fn(
                    query_sparse,
                    passage_sparse,
                    contrastive_labels
                )

            losses['sparse'] = sparse_loss

        # ColBERT loss (if using ColBERT)
        if self.config.use_colbert and 'colbert' in query_outputs and 'colbert' in passage_outputs:
            from models.losses import ColBERTLoss
            colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)

            query_colbert = query_outputs['colbert']  # type: ignore[index]  # [batch_size, query_seq_len, dim]
            passage_colbert = passage_outputs['colbert']  # type: ignore[index]  # [batch_size * num_passages, doc_seq_len, dim]

            # For ColBERT, we need pairwise interactions
            if multiple_passages:
                # Use only positive passages
                labels = batch.get('labels')
                if labels is not None:
                    positive_indices = []
                    for i in range(batch_size):
                        query_labels = labels[i]
                        positive_mask = query_labels > 0
                        if positive_mask.any():
                            pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
                        else:
                            pos_idx = 0
                        positive_indices.append(pos_idx + i * num_passages)

                    positive_colbert = passage_colbert[positive_indices]
                else:
                    positive_colbert = passage_colbert[::num_passages]

                # Binary labels for ColBERT (1 for positive pairs)
                colbert_labels = torch.ones(batch_size, device=query_colbert.device)
            else:
                positive_colbert = passage_colbert
                colbert_labels = torch.ones(batch_size, device=query_colbert.device)

            colbert_loss = colbert_loss_fn(
                query_colbert,
                positive_colbert,
                colbert_labels
            )
            losses['colbert'] = colbert_loss

        # Knowledge distillation loss (only if explicitly enabled)
        if (getattr(self.config, 'use_self_distill', False) and
                self.global_step >= getattr(self.config, 'self_distill_start_step', 0)):
            try:
                # Compute teacher scores (weighted combination)
                teacher_scores = self._compute_teacher_scores(query_outputs, passage_outputs)  # type: ignore[arg-type]
                student_scores = self.model.compute_similarity(
                    query_outputs, passage_outputs, mode='dense'
                )

                kd_loss, _ = self.kd_loss(
                    {'dense': student_scores},
                    {'dense': teacher_scores}
                )
                losses['kd'] = kd_loss
            except Exception as e:
                logger.warning(f"Knowledge distillation failed: {e}")

        # Combine losses
        total_loss, _ = self.multi_task_loss(losses)

        return total_loss

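    # The combination step above assumes MultiTaskLoss returns a weighted sum of the per-task
    # losses using the weights passed in __init__, roughly:
    #   total = w_dense * L_dense + w_sparse * L_sparse + w_colbert * L_colbert + w_distillation * L_kd
    # Note that the distillation term is stored under the key 'kd' while the configured weight
    # key is 'distillation'; how that mapping is resolved is up to MultiTaskLoss and is not
    # verified here.
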
    def _compute_teacher_scores(
        self,
        query_outputs: Dict[str, torch.Tensor],
        passage_outputs: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Compute teacher scores from multiple signals"""
        scores = []
        weights = []

        # Get original batch size for reshaping
        query_batch_size = query_outputs['dense'].shape[0] if 'dense' in query_outputs else 1
        passage_batch_size = passage_outputs['dense'].shape[0] if 'dense' in passage_outputs else 1

        # Check if we have multiple passages per query
        multiple_passages = passage_batch_size > query_batch_size

        if multiple_passages:
            # For multiple passages, we need to compute scores only for positive pairs
            # Use only the first passage for each query as a simplified teacher score
            num_passages = passage_batch_size // query_batch_size

            # Dense scores (use only first passage per query)
            if 'dense' in query_outputs and 'dense' in passage_outputs:
                query_dense = query_outputs['dense']  # [batch_size, dim]
                passage_dense = passage_outputs['dense']  # [batch_size * num_passages, dim]

                # Extract first passage for each query
                positive_passages = passage_dense[::num_passages]  # [batch_size, dim]

                try:
                    dense_scores = self.model.compute_similarity(
                        {'dense': query_dense},
                        {'dense': positive_passages},
                        mode='dense'
                    )
                    scores.append(dense_scores)
                    weights.append(self.config.dense_weight)
                except Exception as e:
                    logger.warning(f"Dense similarity computation failed: {e}")

            # Sparse scores (use only first passage per query)
            if 'sparse' in query_outputs and 'sparse' in passage_outputs:
                query_sparse = query_outputs['sparse']  # [batch_size, query_seq_len]
                passage_sparse = passage_outputs['sparse']  # [batch_size * num_passages, doc_seq_len]

                # Extract first passage for each query
                positive_passages = passage_sparse[::num_passages]  # [batch_size, doc_seq_len]

                try:
                    sparse_scores = self.model.compute_similarity(
                        {'sparse': query_sparse},
                        {'sparse': positive_passages},
                        mode='sparse'
                    )
                    scores.append(sparse_scores)
                    weights.append(self.config.sparse_weight)
                except Exception as e:
                    logger.warning(f"Sparse similarity computation failed: {e}")

            # ColBERT scores (use only first passage per query)
            if 'colbert' in query_outputs and 'colbert' in passage_outputs:
                query_colbert = query_outputs['colbert']  # [batch_size, query_seq_len, dim]
                passage_colbert = passage_outputs['colbert']  # [batch_size * num_passages, doc_seq_len, dim]

                # Extract first passage for each query
                positive_passages = passage_colbert[::num_passages]  # [batch_size, doc_seq_len, dim]

                try:
                    colbert_scores = self.model.compute_similarity(
                        {'colbert': query_colbert},
                        {'colbert': positive_passages},
                        mode='colbert'
                    )
                    scores.append(colbert_scores)
                    weights.append(self.config.colbert_weight)
                except Exception as e:
                    logger.warning(f"ColBERT similarity computation failed: {e}")
        else:
            # Single passage per query - original logic
            # Dense scores
            if 'dense' in query_outputs:
                try:
                    dense_scores = self.model.compute_similarity(
                        {'dense': query_outputs['dense']},
                        {'dense': passage_outputs['dense']},
                        mode='dense'
                    )
                    scores.append(dense_scores)
                    weights.append(self.config.dense_weight)
                except Exception as e:
                    logger.warning(f"Dense similarity computation failed: {e}")

            # Sparse scores
            if 'sparse' in query_outputs:
                try:
                    sparse_scores = self.model.compute_similarity(
                        {'sparse': query_outputs['sparse']},
                        {'sparse': passage_outputs['sparse']},
                        mode='sparse'
                    )
                    scores.append(sparse_scores)
                    weights.append(self.config.sparse_weight)
                except Exception as e:
                    logger.warning(f"Sparse similarity computation failed: {e}")

            # ColBERT scores
            if 'colbert' in query_outputs:
                try:
                    colbert_scores = self.model.compute_similarity(
                        {'colbert': query_outputs['colbert']},
                        {'colbert': passage_outputs['colbert']},
                        mode='colbert'
                    )
                    scores.append(colbert_scores)
                    weights.append(self.config.colbert_weight)
                except Exception as e:
                    logger.warning(f"ColBERT similarity computation failed: {e}")

        # Weighted combination
        if scores:
            weight_tensor = torch.tensor(weights, device=scores[0].device)
            weight_tensor = weight_tensor / weight_tensor.sum()
            teacher_scores = torch.stack([w * s for w, s in zip(weight_tensor, scores)]).sum(dim=0)
            return teacher_scores

        # Fallback to zero scores if no similarity computation succeeded
        return torch.zeros(query_batch_size, device=query_outputs[list(query_outputs.keys())[0]].device)

    def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Evaluate a single batch"""
        # Get query embeddings
        query_outputs = self.model.forward(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True,
            return_dict=True
        )

        # Get passage embeddings
        passage_outputs = self.model.forward(
            input_ids=batch['passage_input_ids'],
            attention_mask=batch['passage_attention_mask'],
            return_dense=True,
            return_dict=True
        )

        # Compute loss
        loss = self.compute_loss(batch)

        # Compute similarity scores
        scores = self.model.compute_similarity(
            query_outputs, passage_outputs, mode='dense'
        )

        # Basic metrics
        labels = batch.get('labels', torch.zeros(scores.shape[0], device=scores.device))
        if labels.dim() > 1:
            labels = labels.view(-1)

        predictions = (scores > 0.5).float()
        accuracy = (predictions == labels.float()).float().mean()

        return {
            'loss': loss.item(),
            'accuracy': accuracy.item(),
            'mean_score': scores.mean().item()
        }

    def train(
        self,
        train_dataset: BGEM3Dataset,
        eval_dataset: Optional[BGEM3Dataset] = None,
        resume_from_checkpoint: Optional[str] = None
    ):
        """
        Main training loop for BGEM3Trainer, with robust error handling, config-driven
        parameters, and progress tracking.
        """
        try:
            train_dataloader = self._create_dataloader(train_dataset, is_train=True)
            eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None

            # Calculate total training steps
            num_update_steps_per_epoch = max(1, len(train_dataloader) // self.config.gradient_accumulation_steps)  # type: ignore[arg-type]
            total_training_steps = num_update_steps_per_epoch * self.config.num_train_epochs

            # Create scheduler with total steps
            self.create_scheduler(total_training_steps)

            if resume_from_checkpoint:
                self.load_checkpoint(resume_from_checkpoint)

            logger.info(f"Starting training for {self.config.num_train_epochs} epochs")
            logger.info(f"Total training steps: {total_training_steps}")

            # Use parent's train method
            super().train(train_dataloader, eval_dataloader, self.config.num_train_epochs)  # type: ignore[arg-type]

        except Exception as e:
            logger.error(f"M3 training failed: {e}")
            raise
        finally:
            # Cleanup
            if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
                self.hard_negative_miner.stop_async_updates()

    def _create_dataloader(
        self,
        dataset: Optional[BGEM3Dataset],
        is_train: bool = True
    ) -> Optional[DataLoader]:
        """Create data loader"""
        if dataset is None:
            return None

        # Import the collate function for embedding batches
        from data.dataset import collate_embedding_batch

        collate_fn = collate_embedding_batch
        num_workers = self.config.dataloader_num_workers

        batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size

        # Setup distributed sampler if needed
        sampler = None
        if self.is_distributed:
            from torch.utils.data.distributed import DistributedSampler
            sampler = DistributedSampler(dataset, shuffle=is_train)

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train and sampler is None,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=self.config.dataloader_pin_memory,
            drop_last=is_train and len(dataset) > batch_size,  # Only drop last if we have multiple batches
            collate_fn=collate_fn
        )

        return dataloader

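    # Note: when a DistributedSampler is used, sampler.set_epoch(epoch) should normally be
    # called at the start of each epoch so that shuffling differs across epochs; it is
    # assumed here that the base trainer's epoch loop takes care of this.
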
    def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train one epoch - use parent's method with hard negative mining integration"""
        # Add hard negative mining updates before calling parent's method
        if self.hard_negative_miner:
            # Set up a hook to update hard negatives during training
            def mining_hook():
                if self.hard_negative_miner:
                    self.hard_negative_miner.update_index_if_needed(self.global_step, self.model)

            # Store the hook for use in compute_loss
            self._mining_hook = mining_hook

        # Use parent's robust training loop
        return super().train_epoch(dataloader, epoch)

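    # Hard-negative mining lifecycle, as wired in this class: the miner is created in
    # __init__ (which also starts its asynchronous index updates), _train_epoch installs
    # self._mining_hook, compute_loss calls that hook on every step so the miner can refresh
    # its index via update_index_if_needed(global_step, model), and train() / __del__ stop
    # the asynchronous updates.
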
    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate the model"""
        self.model.eval()

        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []

        with torch.no_grad():
            with tqdm(dataloader, desc="Evaluating") as tbar:
                for batch in tbar:
                    try:
                        batch = self._prepare_batch(batch)

                        query_outputs = self.model.forward(
                            input_ids=batch['query_input_ids'],
                            attention_mask=batch['query_attention_mask'],
                            return_dense=True,
                            return_dict=True
                        )

                        passage_outputs = self.model.forward(
                            input_ids=batch['passage_input_ids'],
                            attention_mask=batch['passage_attention_mask'],
                            return_dense=True,
                            return_dict=True
                        )

                        all_query_embeddings.append(query_outputs['dense'].cpu())  # type: ignore[index]
                        all_passage_embeddings.append(passage_outputs['dense'].cpu())  # type: ignore[index]
                        all_labels.append(batch['labels'].cpu())

                    except Exception as e:
                        logger.error(f"Error in M3 evaluation step: {e}")
                        continue

        if not all_query_embeddings:
            logger.warning("No valid evaluation batches processed")
            return {'loss': float('inf'), 'accuracy': 0.0}

        # Concatenate all embeddings
        all_query_embeddings = torch.cat(all_query_embeddings, dim=0)
        all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute metrics
        try:
            metrics = compute_retrieval_metrics(
                all_query_embeddings.numpy(),
                all_passage_embeddings.numpy(),
                all_labels.numpy(),
                k_values=[1, 3, 5, 10]
            )
        except Exception as e:
            logger.error(f"Error computing retrieval metrics: {e}")
            metrics = {'loss': float('inf'), 'accuracy': 0.0}

        return metrics

    def save_model(self, save_path: str):
        """Save the trained model"""
        try:
            os.makedirs(save_path, exist_ok=True)

            # Save model
            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
            model_to_save.save_pretrained(save_path)  # type: ignore[attr-defined]

            # Save tokenizer
            self.tokenizer.save_pretrained(save_path)  # type: ignore[attr-defined]

            # Save config
            self.config.save(os.path.join(save_path, 'training_config.json'))

            logger.info(f"Model saved to {save_path}")

        except Exception as e:
            logger.error(f"Error saving model: {e}")
            raise

    def __del__(self):
        """Cleanup when trainer is destroyed"""
        if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
            self.hard_negative_miner.stop_async_updates()
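

# Illustrative usage (a minimal sketch, not part of the module's API; the config loader,
# config path, and dataset arguments below are hypothetical placeholders for the project's
# real config-driven entry point):
#
#   from config.m3_config import BGEM3Config
#   from data.dataset import BGEM3Dataset
#
#   config = BGEM3Config.load("configs/bge_m3.json")   # hypothetical loader and path
#   trainer = BGEM3Trainer(config)
#   train_dataset = BGEM3Dataset(...)                  # built from config-specified data
#   trainer.train(train_dataset, eval_dataset=None)
#   trainer.save_model("outputs/bge-m3-finetuned")     # hypothetical output path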