bge_finetune/scripts/train_reranker.py

#!/usr/bin/env python3
"""
BGE-Reranker Cross-Encoder Model Training Script
Uses the refactored dataset pipeline with flat pairs schema
"""
import os
import sys
import argparse
import logging
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW  # the AdamW copy in transformers.optimization is deprecated; use torch's implementation
import torch.nn.functional as F
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import create_reranker_dataset, collate_reranker_batch, migrate_nested_to_flat
from models.bge_reranker import BGERerankerModel
from training.reranker_trainer import BGERerankerTrainer
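# Note: BGERerankerTrainer is imported but this script currently implements its own training loop below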
from utils.logging import setup_logging, get_logger
logger = get_logger(__name__)
def parse_args():
"""Parse command line arguments for reranker training"""
parser = argparse.ArgumentParser(description="Train BGE-Reranker Cross-Encoder Model")
# Model arguments
parser.add_argument("--model_name_or_path", type=str, default="BAAI/bge-reranker-base",
help="Path to pretrained BGE-Reranker model")
parser.add_argument("--cache_dir", type=str, default="./cache/data",
help="Directory to cache downloaded models")
# Data arguments
parser.add_argument("--reranker_data", type=str, required=True,
help="Path to reranker training data (reranker_data.jsonl)")
parser.add_argument("--eval_data", type=str, default=None,
help="Path to evaluation data")
parser.add_argument("--legacy_data", type=str, default=None,
help="Path to legacy nested data (will be migrated)")
# Training arguments
parser.add_argument("--output_dir", type=str, default="./output/bge-reranker",
help="Directory to save the trained model")
parser.add_argument("--num_train_epochs", type=int, default=3,
help="Number of training epochs")
parser.add_argument("--per_device_train_batch_size", type=int, default=8,
help="Batch size per GPU/CPU for training")
parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
help="Batch size per GPU/CPU for evaluation")
parser.add_argument("--learning_rate", type=float, default=6e-5,
help="Learning rate")
parser.add_argument("--weight_decay", type=float, default=0.01,
help="Weight decay")
parser.add_argument("--warmup_ratio", type=float, default=0.1,
help="Warmup ratio")
# Model specific arguments
parser.add_argument("--query_max_length", type=int, default=64,
help="Maximum query length")
parser.add_argument("--passage_max_length", type=int, default=448,
help="Maximum passage length")
parser.add_argument("--loss_type", type=str, default="cross_entropy",
choices=["cross_entropy", "binary_cross_entropy", "listwise_ce"],
help="Loss function type")
# Enhanced features
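    # Note: --use_score_weighting is parsed here but is not consumed anywhere else in this script yet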
parser.add_argument("--use_score_weighting", action="store_true",
help="Use score weighting in loss calculation")
parser.add_argument("--label_smoothing", type=float, default=0.0,
help="Label smoothing for cross entropy loss")
# System arguments
parser.add_argument("--fp16", action="store_true",
help="Use 16-bit floating point precision")
parser.add_argument("--dataloader_num_workers", type=int, default=4,
help="Number of workers for data loading")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
parser.add_argument("--log_level", type=str, default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level")
# Backward compatibility
parser.add_argument("--legacy_support", action="store_true", default=True,
help="Enable legacy nested format support")
parser.add_argument("--migrate_legacy", action="store_true",
help="Migrate legacy nested format to flat format")
return parser.parse_args()
def setup_model_and_tokenizer(args):
"""Setup BGE-Reranker model and tokenizer"""
logger.info(f"Loading BGE-Reranker model from {args.model_name_or_path}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
cache_dir=args.cache_dir,
trust_remote_code=True
)
# Load model
model = BGERerankerModel(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir
)
return model, tokenizer
def migrate_legacy_data(args):
"""Migrate legacy nested data to flat format if needed"""
if args.legacy_data and args.migrate_legacy:
logger.info("Migrating legacy nested data to flat format...")
# Create output file name
legacy_base = os.path.splitext(args.legacy_data)[0]
migrated_file = f"{legacy_base}_migrated_flat.jsonl"
# Migrate the data
migrate_nested_to_flat(args.legacy_data, migrated_file)
logger.info(f"✅ Migrated legacy data to {migrated_file}")
return migrated_file
return None
def setup_datasets(args, tokenizer):
"""Setup reranker datasets"""
logger.info("Setting up reranker datasets...")
# Migrate legacy data if needed
migrated_file = migrate_legacy_data(args)
# Primary reranker data
train_dataset = create_reranker_dataset(
data_path=args.reranker_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True,
legacy_support=args.legacy_support
)
logger.info(f"Loaded {len(train_dataset)} training examples from {args.reranker_data}")
logger.info(f"Detected format: {train_dataset.format_type}")
# Add migrated legacy data if available
if migrated_file:
logger.info(f"Loading migrated legacy data from {migrated_file}")
migrated_dataset = create_reranker_dataset(
data_path=migrated_file,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=True,
legacy_support=False # Migrated data is in flat format
)
logger.info(f"Loaded {len(migrated_dataset)} migrated examples")
# Combine datasets
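        # Merging via the underlying .data list assumes both datasets share the same flat schema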
train_dataset.data.extend(migrated_dataset.data)
logger.info(f"Total training examples: {len(train_dataset)}")
# Evaluation dataset
eval_dataset = None
if args.eval_data:
eval_dataset = create_reranker_dataset(
data_path=args.eval_data,
tokenizer=tokenizer,
query_max_length=args.query_max_length,
passage_max_length=args.passage_max_length,
is_train=False,
legacy_support=args.legacy_support
)
logger.info(f"Loaded {len(eval_dataset)} evaluation examples")
return train_dataset, eval_dataset
def setup_loss_function(args):
"""Setup cross-encoder loss function"""
if args.loss_type == "cross_entropy":
if args.label_smoothing > 0:
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
else:
loss_fn = torch.nn.CrossEntropyLoss()
elif args.loss_type == "binary_cross_entropy":
loss_fn = torch.nn.BCEWithLogitsLoss()
elif args.loss_type == "listwise_ce":
# Custom listwise cross-entropy loss
loss_fn = torch.nn.CrossEntropyLoss()
else:
raise ValueError(f"Unknown loss type: {args.loss_type}")
logger.info(f"Using loss function: {args.loss_type}")
return loss_fn
def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, device, args):
"""Train for one epoch"""
model.train()
total_loss = 0
num_batches = 0
for batch_idx, batch in enumerate(dataloader):
# Move batch to device
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
# Calculate loss
if args.loss_type == "binary_cross_entropy":
# For binary classification
loss = loss_fn(logits.squeeze(), labels.float())
else:
# For multi-class classification
loss = loss_fn(logits, labels)
# Backward pass
loss.backward()
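        # Gradient clipping is not applied here; torch.nn.utils.clip_grad_norm_ could be added if training is unstable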
optimizer.step()
scheduler.step()
optimizer.zero_grad()
total_loss += loss.item()
num_batches += 1
# Log progress
if batch_idx % 100 == 0:
logger.info(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
return total_loss / num_batches
def evaluate(model, dataloader, loss_fn, device, args):
"""Evaluate the model"""
model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs
# Calculate loss
if args.loss_type == "binary_cross_entropy":
loss = loss_fn(logits.squeeze(), labels.float())
# Calculate accuracy for binary classification
predictions = (torch.sigmoid(logits.squeeze()) > 0.5).long()
else:
loss = loss_fn(logits, labels)
# Calculate accuracy for multi-class classification
predictions = torch.argmax(logits, dim=-1)
total_loss += loss.item()
correct += (predictions == labels).sum().item()
total += labels.size(0)
accuracy = correct / total if total > 0 else 0
return total_loss / len(dataloader), accuracy
def main():
"""Main training function"""
args = parse_args()
    # Create the output directory up front so file logging has somewhere to write
    os.makedirs(args.output_dir, exist_ok=True)
    # Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=args.log_level,
log_to_file=True,
log_to_console=True
)
logger.info("🚀 Starting BGE-Reranker Training")
logger.info(f"Arguments: {vars(args)}")
# Set random seed
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
    # Setup device (fp16 should not force CPU; use CUDA whenever it is available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Setup model and tokenizer
model, tokenizer = setup_model_and_tokenizer(args)
model.to(device)
# Setup datasets
train_dataset, eval_dataset = setup_datasets(args, tokenizer)
# Setup data loaders
train_dataloader = DataLoader(
train_dataset,
batch_size=args.per_device_train_batch_size,
shuffle=True,
num_workers=args.dataloader_num_workers,
collate_fn=collate_reranker_batch,
pin_memory=torch.cuda.is_available()
)
eval_dataloader = None
if eval_dataset:
eval_dataloader = DataLoader(
eval_dataset,
batch_size=args.per_device_eval_batch_size,
shuffle=False,
num_workers=args.dataloader_num_workers,
collate_fn=collate_reranker_batch,
pin_memory=torch.cuda.is_available()
)
# Setup optimizer and scheduler
optimizer = AdamW(
model.parameters(),
lr=args.learning_rate,
weight_decay=args.weight_decay
)
total_steps = len(train_dataloader) * args.num_train_epochs
warmup_steps = int(total_steps * args.warmup_ratio)
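    # Linear warmup over the first warmup_ratio fraction of steps, followed by linear decay to zero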
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps
)
# Setup loss function
loss_fn = setup_loss_function(args)
# Training loop
logger.info("🎯 Starting training...")
best_eval_loss = float('inf')
for epoch in range(args.num_train_epochs):
logger.info(f"Epoch {epoch + 1}/{args.num_train_epochs}")
# Train
train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device, args)
logger.info(f"Training loss: {train_loss:.4f}")
# Evaluate
if eval_dataloader:
eval_loss, eval_accuracy = evaluate(model, eval_dataloader, loss_fn, device, args)
logger.info(f"Evaluation loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")
# Save best model
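            # Assumes BGERerankerModel exposes save_pretrained (e.g. by delegating to the wrapped HF model)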
if eval_loss < best_eval_loss:
best_eval_loss = eval_loss
best_model_path = os.path.join(args.output_dir, "best_model")
os.makedirs(best_model_path, exist_ok=True)
model.save_pretrained(best_model_path)
tokenizer.save_pretrained(best_model_path)
logger.info(f"Saved best model to {best_model_path}")
# Save final model
final_model_path = os.path.join(args.output_dir, "final_model")
os.makedirs(final_model_path, exist_ok=True)
model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)
logger.info(f"✅ Training completed! Final model saved to {final_model_path}")
# Show training summary
logger.info("📊 Training Summary:")
logger.info(f" Total examples: {len(train_dataset)}")
logger.info(f" Format detected: {train_dataset.format_type}")
logger.info(f" Epochs: {args.num_train_epochs}")
logger.info(f" Learning rate: {args.learning_rate}")
logger.info(f" Batch size: {args.per_device_train_batch_size}")
logger.info(f" Loss function: {args.loss_type}")
if __name__ == "__main__":
main()