#!/usr/bin/env python3
"""
BGE-Reranker Cross-Encoder Model Training Script

Uses the refactored dataset pipeline with flat pairs schema
"""

import os
import sys
import argparse
import logging

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
# AdamW is taken from torch.optim; the copy in transformers.optimization is deprecated
# and has been removed in recent transformers releases.
from torch.optim import AdamW
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.dataset import create_reranker_dataset, collate_reranker_batch, migrate_nested_to_flat
from models.bge_reranker import BGERerankerModel
from training.reranker_trainer import BGERerankerTrainer
from utils.logging import setup_logging, get_logger

logger = get_logger(__name__)


def parse_args():
    """Parse command line arguments for reranker training"""
    parser = argparse.ArgumentParser(description="Train BGE-Reranker Cross-Encoder Model")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="BAAI/bge-reranker-base",
                        help="Path to pretrained BGE-Reranker model")
    parser.add_argument("--cache_dir", type=str, default="./cache/data",
                        help="Directory to cache downloaded models")

    # Data arguments
    parser.add_argument("--reranker_data", type=str, required=True,
                        help="Path to reranker training data (reranker_data.jsonl)")
    parser.add_argument("--eval_data", type=str, default=None,
                        help="Path to evaluation data")
    parser.add_argument("--legacy_data", type=str, default=None,
                        help="Path to legacy nested data (will be migrated)")

    # Training arguments
    parser.add_argument("--output_dir", type=str, default="./output/bge-reranker",
                        help="Directory to save the trained model")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=8,
                        help="Batch size per GPU/CPU for training")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
                        help="Batch size per GPU/CPU for evaluation")
    parser.add_argument("--learning_rate", type=float, default=6e-5,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Warmup ratio")

    # Model specific arguments
    parser.add_argument("--query_max_length", type=int, default=64,
                        help="Maximum query length")
    parser.add_argument("--passage_max_length", type=int, default=448,
                        help="Maximum passage length")
    parser.add_argument("--loss_type", type=str, default="cross_entropy",
                        choices=["cross_entropy", "binary_cross_entropy", "listwise_ce"],
                        help="Loss function type")

    # Enhanced features
    parser.add_argument("--use_score_weighting", action="store_true",
                        help="Use score weighting in loss calculation")
    parser.add_argument("--label_smoothing", type=float, default=0.0,
                        help="Label smoothing for cross entropy loss")

    # System arguments
    parser.add_argument("--fp16", action="store_true",
                        help="Use 16-bit floating point precision")
    parser.add_argument("--dataloader_num_workers", type=int, default=4,
                        help="Number of workers for data loading")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--log_level", type=str, default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        help="Logging level")

    # Backward compatibility
    parser.add_argument("--legacy_support", action="store_true", default=True,
                        help="Enable legacy nested format support")
    parser.add_argument("--migrate_legacy", action="store_true",
                        help="Migrate legacy nested format to flat format")

    return parser.parse_args()


def setup_model_and_tokenizer(args):
    """Setup BGE-Reranker model and tokenizer"""
    logger.info(f"Loading BGE-Reranker model from {args.model_name_or_path}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        trust_remote_code=True
    )

    # Load model
    model = BGERerankerModel(
        model_name_or_path=args.model_name_or_path,
        cache_dir=args.cache_dir
    )

    return model, tokenizer


def migrate_legacy_data(args):
    """Migrate legacy nested data to flat format if needed"""
    if args.legacy_data and args.migrate_legacy:
        logger.info("Migrating legacy nested data to flat format...")

        # Create output file name
        legacy_base = os.path.splitext(args.legacy_data)[0]
        migrated_file = f"{legacy_base}_migrated_flat.jsonl"

        # Migrate the data
        migrate_nested_to_flat(args.legacy_data, migrated_file)

        logger.info(f"✅ Migrated legacy data to {migrated_file}")
        return migrated_file

    return None


def setup_datasets(args, tokenizer):
    """Setup reranker datasets"""
    logger.info("Setting up reranker datasets...")

    # Migrate legacy data if needed
    migrated_file = migrate_legacy_data(args)

    # Primary reranker data
    train_dataset = create_reranker_dataset(
        data_path=args.reranker_data,
        tokenizer=tokenizer,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
        is_train=True,
        legacy_support=args.legacy_support
    )

    logger.info(f"Loaded {len(train_dataset)} training examples from {args.reranker_data}")
    logger.info(f"Detected format: {train_dataset.format_type}")

    # Add migrated legacy data if available
    if migrated_file:
        logger.info(f"Loading migrated legacy data from {migrated_file}")
        migrated_dataset = create_reranker_dataset(
            data_path=migrated_file,
            tokenizer=tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            is_train=True,
            legacy_support=False  # Migrated data is in flat format
        )
        logger.info(f"Loaded {len(migrated_dataset)} migrated examples")

        # Combine datasets
        train_dataset.data.extend(migrated_dataset.data)
        logger.info(f"Total training examples: {len(train_dataset)}")

    # Evaluation dataset
    eval_dataset = None
    if args.eval_data:
        eval_dataset = create_reranker_dataset(
            data_path=args.eval_data,
            tokenizer=tokenizer,
            query_max_length=args.query_max_length,
            passage_max_length=args.passage_max_length,
            is_train=False,
            legacy_support=args.legacy_support
        )
        logger.info(f"Loaded {len(eval_dataset)} evaluation examples")

    return train_dataset, eval_dataset


def setup_loss_function(args):
    """Setup cross-encoder loss function"""
    if args.loss_type == "cross_entropy":
        if args.label_smoothing > 0:
            loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
        else:
            loss_fn = torch.nn.CrossEntropyLoss()
    elif args.loss_type == "binary_cross_entropy":
        loss_fn = torch.nn.BCEWithLogitsLoss()
    elif args.loss_type == "listwise_ce":
        # Currently falls back to a standard cross-entropy loss; see the illustrative
        # listwise formulation sketched after this function.
        loss_fn = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError(f"Unknown loss type: {args.loss_type}")

    logger.info(f"Using loss function: {args.loss_type}")
    return loss_fn
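

# The "listwise_ce" option above reuses a standard CrossEntropyLoss. For reference, a
# minimal sketch of a true listwise cross-entropy over one query's candidate group is
# shown below. It is illustrative only: the helper is not wired into the trainer, and
# grouping scores per query is an assumption about the data pipeline, not part of it.
def listwise_ce_example(scores: torch.Tensor, positive_index: int) -> torch.Tensor:
    """Listwise CE for a single query: softmax over its candidate scores,
    negative log-likelihood of the known-positive candidate."""
    # scores: shape (num_candidates,), one relevance score per candidate passage
    target = torch.tensor([positive_index], device=scores.device)
    return F.cross_entropy(scores.unsqueeze(0), target)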


def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, device, args):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    num_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs

        # Calculate loss
        if args.loss_type == "binary_cross_entropy":
            # For binary classification: squeeze only the last dim so batch size 1 still works
            loss = loss_fn(logits.squeeze(-1), labels.float())
        else:
            # For multi-class classification
            loss = loss_fn(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        num_batches += 1

        # Log progress
        if batch_idx % 100 == 0:
            logger.info(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")

    return total_loss / num_batches
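

# Note: the --fp16 flag is parsed but not applied in the loop above. A minimal
# mixed-precision variant (assuming a CUDA device and torch.cuda.amp) would wrap the
# forward/backward pass roughly as in the commented sketch below; this is not active code.
#
#     scaler = torch.cuda.amp.GradScaler()
#     with torch.cuda.amp.autocast():
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
#         loss = loss_fn(logits, labels)
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()
#     optimizer.zero_grad()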


def evaluate(model, dataloader, loss_fn, device, args):
    """Evaluate the model"""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs

            # Calculate loss
            if args.loss_type == "binary_cross_entropy":
                loss = loss_fn(logits.squeeze(-1), labels.float())
                # Accuracy for binary classification
                predictions = (torch.sigmoid(logits.squeeze(-1)) > 0.5).long()
            else:
                loss = loss_fn(logits, labels)
                # Accuracy for multi-class classification
                predictions = torch.argmax(logits, dim=-1)

            total_loss += loss.item()
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total if total > 0 else 0
    return total_loss / len(dataloader), accuracy


def main():
    """Main training function"""
    args = parse_args()

    # Setup logging
    setup_logging(
        output_dir=args.output_dir,
        log_level=args.log_level,
        log_to_file=True,
        log_to_console=True
    )

    logger.info("🚀 Starting BGE-Reranker Training")
    logger.info(f"Arguments: {vars(args)}")

    # Set random seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Setup device (use CUDA whenever available; --fp16 controls precision, not device placement)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args)
    model.to(device)

    # Setup datasets
    train_dataset, eval_dataset = setup_datasets(args, tokenizer)

    # Setup data loaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.per_device_train_batch_size,
        shuffle=True,
        num_workers=args.dataloader_num_workers,
        collate_fn=collate_reranker_batch,
        pin_memory=torch.cuda.is_available()
    )

    eval_dataloader = None
    if eval_dataset:
        eval_dataloader = DataLoader(
            eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            shuffle=False,
            num_workers=args.dataloader_num_workers,
            collate_fn=collate_reranker_batch,
            pin_memory=torch.cuda.is_available()
        )

    # Setup optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    total_steps = len(train_dataloader) * args.num_train_epochs
    warmup_steps = int(total_steps * args.warmup_ratio)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Setup loss function
    loss_fn = setup_loss_function(args)

    # Training loop
    logger.info("🎯 Starting training...")

    best_eval_loss = float('inf')

    for epoch in range(args.num_train_epochs):
        logger.info(f"Epoch {epoch + 1}/{args.num_train_epochs}")

        # Train
        train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device, args)
        logger.info(f"Training loss: {train_loss:.4f}")

        # Evaluate
        if eval_dataloader:
            eval_loss, eval_accuracy = evaluate(model, eval_dataloader, loss_fn, device, args)
            logger.info(f"Evaluation loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")

            # Save best model
            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                best_model_path = os.path.join(args.output_dir, "best_model")
                os.makedirs(best_model_path, exist_ok=True)
                model.save_pretrained(best_model_path)
                tokenizer.save_pretrained(best_model_path)
                logger.info(f"Saved best model to {best_model_path}")

    # Save final model
    final_model_path = os.path.join(args.output_dir, "final_model")
    os.makedirs(final_model_path, exist_ok=True)
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    logger.info(f"✅ Training completed! Final model saved to {final_model_path}")

    # Show training summary
    logger.info("📊 Training Summary:")
    logger.info(f"  Total examples: {len(train_dataset)}")
    logger.info(f"  Format detected: {train_dataset.format_type}")
    logger.info(f"  Epochs: {args.num_train_epochs}")
    logger.info(f"  Learning rate: {args.learning_rate}")
    logger.info(f"  Batch size: {args.per_device_train_batch_size}")
    logger.info(f"  Loss function: {args.loss_type}")


if __name__ == "__main__":
    main()