# bge_finetune/training/m3_trainer.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer
import os
import json
import logging
from tqdm import tqdm
from typing import Dict, Optional, List, Tuple
import numpy as np
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from .trainer import BaseTrainer
from models.bge_m3 import BGEM3Model
from models.losses import InfoNCELoss, M3KnowledgeDistillationLoss, MultiTaskLoss
from data.dataset import BGEM3Dataset
# Optional hard negative mining
try:
from data.hard_negative_mining import ANCEHardNegativeMiner
ANCE_AVAILABLE = True
except ImportError:
ANCE_AVAILABLE = False
ANCEHardNegativeMiner = None
from utils.checkpoint import CheckpointManager
from utils.logging import get_logger
from evaluation.metrics import compute_retrieval_metrics
from config.m3_config import BGEM3Config
# Import scheduler with fallback
try:
from transformers import get_linear_schedule_with_warmup
except ImportError:
get_linear_schedule_with_warmup = None
logger = get_logger(__name__)
class BGEM3Trainer(BaseTrainer):
"""
    Trainer for the BGE-M3 embedding model: multi-task fine-tuning over dense, sparse (lexical),
    and ColBERT multi-vector heads, with optional self-knowledge distillation and ANCE-style
    hard-negative mining.
Notes:
- All parameters (config, model, tokenizer, etc.) should be passed from config-driven training scripts.
- This class does not load config directly; pass config values from your main script for consistency.
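    Example (illustrative sketch; assumes BGEM3Config accepts these fields as constructor
    arguments and that train_dataset / eval_dataset are BGEM3Dataset instances built elsewhere):

        config = BGEM3Config(model_name_or_path="BAAI/bge-m3")
        trainer = BGEM3Trainer(config)
        trainer.train(train_dataset, eval_dataset)
        trainer.save_model("outputs/bge-m3-finetuned")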
"""
def __init__(
self,
config: BGEM3Config,
model: Optional[BGEM3Model] = None,
tokenizer: Optional[AutoTokenizer] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: Optional[torch.device] = None
):
# Store config first
self.config = config
# Initialize tokenizer first
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
config.model_name_or_path,
cache_dir=config.cache_dir,
trust_remote_code=True
)
else:
self.tokenizer = tokenizer
# Initialize model
if model is None:
self.model = BGEM3Model(
model_name_or_path=config.model_name_or_path,
use_dense=config.use_dense,
use_sparse=config.use_sparse,
use_colbert=config.use_colbert,
pooling_method=config.sentence_pooling_method,
normalize_embeddings=config.normalize_embeddings,
colbert_dim=config.colbert_dim,
sparse_top_k=config.sparse_top_k,
cache_dir=config.cache_dir
)
else:
self.model = model
# Initialize optimizer if not provided
if optimizer is None:
optimizer = self._create_optimizer()
# Initialize scheduler if not provided and get_linear_schedule_with_warmup is available
if scheduler is None and get_linear_schedule_with_warmup is not None:
# We'll create it later when we know the total training steps
pass
# Initialize parent class
super().__init__(config, self.model, optimizer, scheduler, device)
# Initialize losses
self.infonce_loss = InfoNCELoss(
temperature=config.temperature,
reduction='mean'
)
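        # InfoNCELoss is called as (queries [B, D], positives [B, D], negatives [B, N, D] or None);
        # with negatives=None it falls back to in-batch negatives (see _compute_embedding_loss).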
self.kd_loss = M3KnowledgeDistillationLoss(
temperature=config.self_distill_temperature,
alpha=config.self_distill_alpha,
distill_types=['dense', 'sparse', 'colbert']
)
self.multi_task_loss = MultiTaskLoss(
loss_weights={
'dense': config.dense_weight,
'sparse': config.sparse_weight,
'colbert': config.colbert_weight,
'distillation': config.distillation_weight
}
)
# Initialize hard negative miner
if config.use_hard_negatives and ANCE_AVAILABLE:
self.hard_negative_miner = ANCEHardNegativeMiner(
embedding_dim=self.model.config.hidden_size, # type: ignore[attr-defined]
num_hard_negatives=config.num_hard_negatives,
negative_range=config.ance_negative_range,
margin=config.ance_margin,
index_refresh_interval=config.hard_negative_mining_steps,
                use_gpu=torch.device(self.device).type == "cuda"  # also matches "cuda:0", "cuda:1", ...
)
self.hard_negative_miner.start_async_updates()
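            # The miner maintains an embedding index and refreshes it asynchronously, at most
            # every `hard_negative_mining_steps` steps (driven by the mining hook in _train_epoch).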
else:
if config.use_hard_negatives and not ANCE_AVAILABLE:
logger.warning("Hard negative mining requested but ANCEHardNegativeMiner not available")
self.hard_negative_miner = None
logger.info(f"Initialized BGE-M3 trainer with config: {config.to_dict()}")
def _create_optimizer(self) -> torch.optim.Optimizer:
"""Create optimizer with parameter groups"""
# Separate parameters for different learning rates if needed
param_groups = [
{
'params': [p for n, p in self.model.named_parameters()
if 'colbert_linear' not in n and 'sparse_linear' not in n],
'lr': self.config.learning_rate
}
]
# Different LR for task-specific layers
if self.config.use_colbert and hasattr(self.model, 'colbert_linear'):
param_groups.append({
'params': self.model.colbert_linear.parameters(),
'lr': self.config.learning_rate * 2 # Higher LR for new layers
})
if self.config.use_sparse and hasattr(self.model, 'sparse_linear'):
param_groups.append({
'params': self.model.sparse_linear.parameters(),
'lr': self.config.learning_rate * 2
})
optimizer = AdamW(
param_groups,
eps=self.config.adam_epsilon,
weight_decay=self.config.weight_decay,
betas=(self.config.adam_beta1, self.config.adam_beta2)
)
return optimizer
def create_scheduler(self, num_training_steps: int):
"""Create learning rate scheduler"""
if get_linear_schedule_with_warmup is None:
            logger.warning("get_linear_schedule_with_warmup not available; using a constant learning rate")
return None
num_warmup_steps = self.config.warmup_steps or int(num_training_steps * self.config.warmup_ratio)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
return self.scheduler
def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute M3 loss with multiple objectives"""
# Update hard negative mining if hook is available
if hasattr(self, '_mining_hook'):
self._mining_hook()
# Detect batch format and handle accordingly
if 'pos_input_ids' in batch and 'neg_input_ids' in batch:
# This is triplet format from FastM3Dataset
return self._compute_triplet_loss(batch)
elif 'passage_input_ids' in batch:
# This is standard embedding format
return self._compute_embedding_loss(batch)
else:
raise ValueError(f"Unknown batch format. Keys: {list(batch.keys())}")
def _compute_triplet_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for triplet format data (FastM3Dataset)"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get positive embeddings
pos_outputs = self.model.forward(
input_ids=batch['pos_input_ids'],
attention_mask=batch['pos_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get negative embeddings
neg_outputs = self.model.forward(
input_ids=batch['neg_input_ids'],
attention_mask=batch['neg_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
losses = {}
# Dense loss (InfoNCE with explicit pos/neg)
if 'dense' in query_outputs:
query_dense = query_outputs['dense'].float() # type: ignore[index]
pos_dense = pos_outputs['dense'].float() # type: ignore[index]
neg_dense = neg_outputs['dense'].float() # type: ignore[index]
# For triplet format, reshape negative embeddings to expected 3D format
# InfoNCE expects negative_embeddings: [batch_size, num_negatives, embedding_dim]
# But we have single negatives: [batch_size, embedding_dim]
# Reshape to [batch_size, 1, embedding_dim]
neg_dense_3d = neg_dense.unsqueeze(1) # [batch_size, 1, embedding_dim]
dense_loss = self.infonce_loss(
query_dense,
pos_dense,
neg_dense_3d # Now properly shaped 3D tensor
)
losses['dense'] = dense_loss
# Sparse loss
if self.config.use_sparse and 'sparse' in query_outputs:
from models.losses import SparseLoss
sparse_loss_fn = SparseLoss(
flops_loss_weight=1e-3,
sparse_reg_weight=1e-4
)
# Combine pos and neg for sparse loss
combined_sparse = torch.cat([pos_outputs['sparse'], neg_outputs['sparse']], dim=0) # type: ignore[index]
query_sparse_repeated = query_outputs['sparse'].repeat(2, 1) # type: ignore[index]
labels = torch.cat([torch.ones(len(pos_outputs['sparse'])), # type: ignore[index]
torch.zeros(len(neg_outputs['sparse']))], dim=0).to(query_sparse_repeated.device) # type: ignore[index]
sparse_loss, _ = sparse_loss_fn(
query_sparse_repeated,
combined_sparse,
labels
)
losses['sparse'] = sparse_loss
# ColBERT loss
if self.config.use_colbert and 'colbert' in query_outputs:
from models.losses import ColBERTLoss
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
# Combine pos and neg for ColBERT loss
combined_colbert = torch.cat([pos_outputs['colbert'], neg_outputs['colbert']], dim=0) # type: ignore[index]
query_colbert_repeated = query_outputs['colbert'].repeat(2, 1, 1) # type: ignore[index]
labels = torch.cat([torch.ones(len(pos_outputs['colbert'])), # type: ignore[index]
torch.zeros(len(neg_outputs['colbert']))], dim=0).to(query_colbert_repeated.device) # type: ignore[index]
colbert_loss = colbert_loss_fn(
query_colbert_repeated,
combined_colbert,
labels
)
losses['colbert'] = colbert_loss
# Combine losses
total_loss, _ = self.multi_task_loss(losses)
return total_loss
def _compute_embedding_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute loss for standard embedding format data"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=self.config.use_dense,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Get passage embeddings - handle both single and multiple passages
if batch['passage_input_ids'].dim() == 3:
# Multiple passages per query [batch, num_passages, seq_len]
batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
passage_input_ids = batch['passage_input_ids'].view(batch_size * num_passages, seq_len)
passage_attention_mask = batch['passage_attention_mask'].view(batch_size * num_passages, seq_len)
multiple_passages = True
else:
# Single passage per query [batch, seq_len]
passage_input_ids = batch['passage_input_ids']
passage_attention_mask = batch['passage_attention_mask']
batch_size = passage_input_ids.shape[0]
num_passages = 1
multiple_passages = False
passage_outputs = self.model.forward(
input_ids=passage_input_ids,
attention_mask=passage_attention_mask,
return_dense=True,
return_sparse=self.config.use_sparse,
return_colbert=self.config.use_colbert,
return_dict=True
)
# Compute losses for different components
losses = {}
# Dense loss (InfoNCE)
if 'dense' in query_outputs and 'dense' in passage_outputs:
# Get embeddings and ensure they're on the same device
query_dense = query_outputs['dense'].float() # type: ignore[index] # [batch_size, embedding_dim]
passage_dense = passage_outputs['dense'].float() # type: ignore[index] # [batch_size * num_passages, embedding_dim]
# Handle multiple passages per query differently for contrastive learning
if multiple_passages:
# Reshape passage embeddings back to [batch_size, num_passages, embedding_dim]
passage_dense = passage_dense.view(batch_size, num_passages, -1)
# For contrastive learning, we need to identify positive passages
# Use the labels to find the first positive passage for each query
labels = batch.get('labels') # [batch_size, num_passages]
if labels is not None:
# Find the first positive passage for each query
positive_indices = []
negative_passages = []
for i in range(batch_size):
query_labels = labels[i] # [num_passages]
positive_mask = query_labels > 0
if positive_mask.any():
# Get first positive passage
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
positive_indices.append(pos_idx)
# Collect negative passages
neg_mask = ~positive_mask
if neg_mask.any():
neg_embeddings = passage_dense[i][neg_mask] # [num_negatives, embedding_dim]
negative_passages.append(neg_embeddings)
else:
negative_passages.append(None)
else:
# No positive passage, use first passage as positive (fallback)
positive_indices.append(0)
negative_passages.append(passage_dense[i][1:])
# Extract positive embeddings
positive_embeddings = torch.stack([
passage_dense[i, positive_indices[i]] for i in range(batch_size)
]) # [batch_size, embedding_dim]
# Stack negative embeddings (if available)
max_neg = max(neg.shape[0] if neg is not None else 0 for neg in negative_passages)
if max_neg > 0:
# Pad negative embeddings to same size
padded_negatives = []
for neg in negative_passages:
if neg is not None and neg.shape[0] > 0:
if neg.shape[0] < max_neg:
# Pad with zeros
padding = torch.zeros(max_neg - neg.shape[0], neg.shape[1], device=neg.device)
neg = torch.cat([neg, padding], dim=0)
padded_negatives.append(neg)
else:
# No negatives, create zero tensor
padded_negatives.append(torch.zeros(max_neg, passage_dense.shape[-1], device=passage_dense.device))
negative_embeddings = torch.stack(padded_negatives) # [batch_size, max_neg, embedding_dim]
else:
negative_embeddings = None
# Compute InfoNCE loss with explicit positives and negatives
dense_loss = self.infonce_loss(
query_dense,
positive_embeddings,
negative_embeddings
)
else:
# No labels provided, use first passage as positive, rest as negatives
positive_embeddings = passage_dense[:, 0, :] # [batch_size, embedding_dim]
if num_passages > 1:
negative_embeddings = passage_dense[:, 1:, :] # [batch_size, num_passages-1, embedding_dim]
else:
negative_embeddings = None
dense_loss = self.infonce_loss(
query_dense,
positive_embeddings,
negative_embeddings
)
else:
# Single passage per query - use standard in-batch negatives
# InfoNCE will create contrastive pairs using in-batch negatives
dense_loss = self.infonce_loss(
query_dense,
passage_dense,
None # No explicit negatives, will use in-batch
)
losses['dense'] = dense_loss
# Sparse loss (if using sparse)
if self.config.use_sparse and 'sparse' in query_outputs and 'sparse' in passage_outputs:
from models.losses import SparseLoss
sparse_loss_fn = SparseLoss(
flops_loss_weight=1e-3,
sparse_reg_weight=1e-4
)
query_sparse = query_outputs['sparse'] # type: ignore[index] # [batch_size, query_seq_len]
passage_sparse = passage_outputs['sparse'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len]
            # Sparse embeddings may have different sequence lengths, so select one positive
            # passage per query (same selection strategy as the dense branch above)
if multiple_passages:
# Use only the first positive passage for each query
labels = batch.get('labels')
if labels is not None:
positive_indices = []
for i in range(batch_size):
query_labels = labels[i]
positive_mask = query_labels > 0
if positive_mask.any():
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
else:
pos_idx = 0
positive_indices.append(pos_idx + i * num_passages)
# Extract positive sparse embeddings
positive_sparse = passage_sparse[positive_indices] # [batch_size, doc_seq_len]
else:
# Use first passage as positive
positive_sparse = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
# Create contrastive labels for sparse loss
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
sparse_loss, _ = sparse_loss_fn(
query_sparse,
positive_sparse,
contrastive_labels
)
else:
# Single passage per query
contrastive_labels = torch.arange(batch_size, device=query_sparse.device, dtype=torch.long)
sparse_loss, _ = sparse_loss_fn(
query_sparse,
passage_sparse,
contrastive_labels
)
losses['sparse'] = sparse_loss
# ColBERT loss (if using ColBERT)
if self.config.use_colbert and 'colbert' in query_outputs and 'colbert' in passage_outputs:
from models.losses import ColBERTLoss
colbert_loss_fn = ColBERTLoss(temperature=self.config.temperature)
query_colbert = query_outputs['colbert'] # type: ignore[index] # [batch_size, query_seq_len, dim]
passage_colbert = passage_outputs['colbert'] # type: ignore[index] # [batch_size * num_passages, doc_seq_len, dim]
# For ColBERT, we need pairwise interactions
if multiple_passages:
# Use only positive passages
labels = batch.get('labels')
if labels is not None:
positive_indices = []
for i in range(batch_size):
query_labels = labels[i]
positive_mask = query_labels > 0
if positive_mask.any():
pos_idx = positive_mask.nonzero(as_tuple=True)[0][0].item()
else:
pos_idx = 0
positive_indices.append(pos_idx + i * num_passages)
positive_colbert = passage_colbert[positive_indices]
else:
positive_colbert = passage_colbert[::num_passages]
# Binary labels for ColBERT (1 for positive pairs)
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
else:
positive_colbert = passage_colbert
colbert_labels = torch.ones(batch_size, device=query_colbert.device)
colbert_loss = colbert_loss_fn(
query_colbert,
positive_colbert,
colbert_labels
)
losses['colbert'] = colbert_loss
# Knowledge distillation loss (only if explicitly enabled)
if (getattr(self.config, 'use_self_distill', False) and
self.global_step >= getattr(self.config, 'self_distill_start_step', 0)):
try:
# Compute teacher scores (weighted combination)
teacher_scores = self._compute_teacher_scores(query_outputs, passage_outputs) # type: ignore[arg-type]
student_scores = self.model.compute_similarity(
query_outputs, passage_outputs, mode='dense'
)
kd_loss, _ = self.kd_loss(
{'dense': student_scores},
{'dense': teacher_scores}
)
losses['kd'] = kd_loss
except Exception as e:
logger.warning(f"Knowledge distillation failed: {e}")
# Combine losses
total_loss, _ = self.multi_task_loss(losses)
return total_loss
def _compute_teacher_scores(
self,
query_outputs: Dict[str, torch.Tensor],
passage_outputs: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""Compute teacher scores from multiple signals"""
scores = []
weights = []
# Get original batch size for reshaping
query_batch_size = query_outputs['dense'].shape[0] if 'dense' in query_outputs else 1
passage_batch_size = passage_outputs['dense'].shape[0] if 'dense' in passage_outputs else 1
# Check if we have multiple passages per query
multiple_passages = passage_batch_size > query_batch_size
if multiple_passages:
# For multiple passages, we need to compute scores only for positive pairs
# Use only the first passage for each query as a simplified teacher score
num_passages = passage_batch_size // query_batch_size
# Dense scores (use only first passage per query)
if 'dense' in query_outputs and 'dense' in passage_outputs:
query_dense = query_outputs['dense'] # [batch_size, dim]
passage_dense = passage_outputs['dense'] # [batch_size * num_passages, dim]
# Extract first passage for each query
positive_passages = passage_dense[::num_passages] # [batch_size, dim]
try:
dense_scores = self.model.compute_similarity(
{'dense': query_dense},
{'dense': positive_passages},
mode='dense'
)
scores.append(dense_scores)
weights.append(self.config.dense_weight)
except Exception as e:
logger.warning(f"Dense similarity computation failed: {e}")
# Sparse scores (use only first passage per query)
if 'sparse' in query_outputs and 'sparse' in passage_outputs:
query_sparse = query_outputs['sparse'] # [batch_size, query_seq_len]
passage_sparse = passage_outputs['sparse'] # [batch_size * num_passages, doc_seq_len]
# Extract first passage for each query
positive_passages = passage_sparse[::num_passages] # [batch_size, doc_seq_len]
try:
sparse_scores = self.model.compute_similarity(
{'sparse': query_sparse},
{'sparse': positive_passages},
mode='sparse'
)
scores.append(sparse_scores)
weights.append(self.config.sparse_weight)
except Exception as e:
logger.warning(f"Sparse similarity computation failed: {e}")
# ColBERT scores (use only first passage per query)
if 'colbert' in query_outputs and 'colbert' in passage_outputs:
query_colbert = query_outputs['colbert'] # [batch_size, query_seq_len, dim]
passage_colbert = passage_outputs['colbert'] # [batch_size * num_passages, doc_seq_len, dim]
# Extract first passage for each query
positive_passages = passage_colbert[::num_passages] # [batch_size, doc_seq_len, dim]
try:
colbert_scores = self.model.compute_similarity(
{'colbert': query_colbert},
{'colbert': positive_passages},
mode='colbert'
)
scores.append(colbert_scores)
weights.append(self.config.colbert_weight)
except Exception as e:
logger.warning(f"ColBERT similarity computation failed: {e}")
else:
# Single passage per query - original logic
# Dense scores
if 'dense' in query_outputs:
try:
dense_scores = self.model.compute_similarity(
{'dense': query_outputs['dense']},
{'dense': passage_outputs['dense']},
mode='dense'
)
scores.append(dense_scores)
weights.append(self.config.dense_weight)
except Exception as e:
logger.warning(f"Dense similarity computation failed: {e}")
# Sparse scores
if 'sparse' in query_outputs:
try:
sparse_scores = self.model.compute_similarity(
{'sparse': query_outputs['sparse']},
{'sparse': passage_outputs['sparse']},
mode='sparse'
)
scores.append(sparse_scores)
weights.append(self.config.sparse_weight)
except Exception as e:
logger.warning(f"Sparse similarity computation failed: {e}")
# ColBERT scores
if 'colbert' in query_outputs:
try:
colbert_scores = self.model.compute_similarity(
{'colbert': query_outputs['colbert']},
{'colbert': passage_outputs['colbert']},
mode='colbert'
)
scores.append(colbert_scores)
weights.append(self.config.colbert_weight)
except Exception as e:
logger.warning(f"ColBERT similarity computation failed: {e}")
# Weighted combination
if scores:
weights = torch.tensor(weights, device=scores[0].device)
weights = weights / weights.sum()
teacher_scores = torch.stack([w * s for w, s in zip(weights, scores)]).sum(dim=0)
return teacher_scores
        # Fallback: no head produced valid scores, so return zero scores (one per query)
        return torch.zeros(query_batch_size, device=next(iter(query_outputs.values())).device)
def evaluate_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
"""Evaluate a single batch"""
# Get query embeddings
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True,
return_dict=True
)
# Get passage embeddings
passage_outputs = self.model.forward(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True,
return_dict=True
)
# Compute loss
loss = self.compute_loss(batch)
# Compute similarity scores
scores = self.model.compute_similarity(
query_outputs, passage_outputs, mode='dense'
)
# Basic metrics
        labels = batch.get('labels', torch.zeros(scores.shape[0], device=scores.device))
if labels.dim() > 1:
labels = labels.view(-1)
predictions = (scores > 0.5).float()
accuracy = (predictions == labels.float()).float().mean()
return {
'loss': loss.item(),
'accuracy': accuracy.item(),
'mean_score': scores.mean().item()
}
def train(
self,
train_dataset: BGEM3Dataset,
eval_dataset: Optional[BGEM3Dataset] = None,
resume_from_checkpoint: Optional[str] = None
):
"""
Main training loop for M3Trainer with robust error handling, config-driven parameters, and progress tracking.
"""
try:
train_dataloader = self._create_dataloader(train_dataset, is_train=True)
eval_dataloader = self._create_dataloader(eval_dataset, is_train=False) if eval_dataset else None
# Calculate total training steps
num_update_steps_per_epoch = max(1, len(train_dataloader) // self.config.gradient_accumulation_steps) # type: ignore[arg-type]
total_training_steps = num_update_steps_per_epoch * self.config.num_train_epochs
# Create scheduler with total steps
self.create_scheduler(total_training_steps)
if resume_from_checkpoint:
self.load_checkpoint(resume_from_checkpoint)
logger.info(f"Starting training for {self.config.num_train_epochs} epochs")
logger.info(f"Total training steps: {total_training_steps}")
# Use parent's train method
super().train(train_dataloader, eval_dataloader, self.config.num_train_epochs) # type: ignore[arg-type]
except Exception as e:
logger.error(f"M3 training failed: {e}")
raise
finally:
# Cleanup
if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
self.hard_negative_miner.stop_async_updates()
def _create_dataloader(
self,
dataset: Optional[BGEM3Dataset],
is_train: bool = True
) -> Optional[DataLoader]:
"""Create data loader"""
if dataset is None:
return None
        # Use the shared embedding collate function for both training and evaluation batches
        from data.dataset import collate_embedding_batch
        collate_fn = collate_embedding_batch
num_workers = self.config.dataloader_num_workers
batch_size = self.config.per_device_train_batch_size if is_train else self.config.per_device_eval_batch_size
# Setup distributed sampler if needed
sampler = None
if self.is_distributed:
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, shuffle=is_train)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=is_train and sampler is None,
sampler=sampler,
num_workers=num_workers,
pin_memory=self.config.dataloader_pin_memory,
drop_last=is_train and len(dataset) > batch_size, # Only drop last if we have multiple batches
collate_fn=collate_fn
)
return dataloader
def _train_epoch(self, dataloader: DataLoader, epoch: int) -> Dict[str, float]:
"""Train one epoch - use parent's method with hard negative mining integration"""
# Add hard negative mining updates before calling parent's method
if self.hard_negative_miner:
# Set up a hook to update hard negatives during training
def mining_hook():
if self.hard_negative_miner:
self.hard_negative_miner.update_index_if_needed(self.global_step, self.model)
# Store the hook for use in compute_loss
self._mining_hook = mining_hook
# Use parent's robust training loop
return super().train_epoch(dataloader, epoch)
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
"""Evaluate the model"""
self.model.eval()
all_query_embeddings = []
all_passage_embeddings = []
all_labels = []
with torch.no_grad():
with tqdm(dataloader, desc="Evaluating") as tbar:
for batch in tbar:
try:
batch = self._prepare_batch(batch)
query_outputs = self.model.forward(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True,
return_dict=True
)
passage_outputs = self.model.forward(
input_ids=batch['passage_input_ids'],
attention_mask=batch['passage_attention_mask'],
return_dense=True,
return_dict=True
)
all_query_embeddings.append(query_outputs['dense'].cpu()) # type: ignore[index]
all_passage_embeddings.append(passage_outputs['dense'].cpu()) # type: ignore[index]
all_labels.append(batch['labels'].cpu())
except Exception as e:
logger.error(f"Error in M3 evaluation step: {e}")
continue
if not all_query_embeddings:
logger.warning("No valid evaluation batches processed")
return {'loss': float('inf'), 'accuracy': 0.0}
# Concatenate all embeddings
all_query_embeddings = torch.cat(all_query_embeddings, dim=0)
all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0)
all_labels = torch.cat(all_labels, dim=0)
# Compute metrics
try:
metrics = compute_retrieval_metrics(
all_query_embeddings.numpy(),
all_passage_embeddings.numpy(),
all_labels.numpy(),
k_values=[1, 3, 5, 10]
)
except Exception as e:
logger.error(f"Error computing retrieval metrics: {e}")
metrics = {'loss': float('inf'), 'accuracy': 0.0}
return metrics
def save_model(self, save_path: str):
"""Save the trained model"""
try:
os.makedirs(save_path, exist_ok=True)
# Save model
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
model_to_save.save_pretrained(save_path) # type: ignore[attr-defined]
# Save tokenizer
self.tokenizer.save_pretrained(save_path) # type: ignore[attr-defined]
# Save config
self.config.save(os.path.join(save_path, 'training_config.json'))
logger.info(f"Model saved to {save_path}")
except Exception as e:
logger.error(f"Error saving model: {e}")
raise
def __del__(self):
"""Cleanup when trainer is destroyed"""
if hasattr(self, 'hard_negative_miner') and self.hard_negative_miner:
self.hard_negative_miner.stop_async_updates()