init commit
This commit is contained in:
398
scripts/train_reranker.py
Normal file
398
scripts/train_reranker.py
Normal file
@@ -0,0 +1,398 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE-Reranker Cross-Encoder Model Training Script
|
||||
Uses the refactored dataset pipeline with flat pairs schema
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
|
||||
from transformers.optimization import AdamW
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from data.dataset import create_reranker_dataset, collate_reranker_batch, migrate_nested_to_flat
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from training.reranker_trainer import BGERerankerTrainer
|
||||
from utils.logging import setup_logging, get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
    """Parse command line arguments for reranker training.

    Returns:
        argparse.Namespace: parsed training configuration.
    """
    parser = argparse.ArgumentParser(description="Train BGE-Reranker Cross-Encoder Model")

    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="BAAI/bge-reranker-base",
                        help="Path to pretrained BGE-Reranker model")
    parser.add_argument("--cache_dir", type=str, default="./cache/data",
                        help="Directory to cache downloaded models")

    # Data arguments
    parser.add_argument("--reranker_data", type=str, required=True,
                        help="Path to reranker training data (reranker_data.jsonl)")
    parser.add_argument("--eval_data", type=str, default=None,
                        help="Path to evaluation data")
    parser.add_argument("--legacy_data", type=str, default=None,
                        help="Path to legacy nested data (will be migrated)")

    # Training arguments
    parser.add_argument("--output_dir", type=str, default="./output/bge-reranker",
                        help="Directory to save the trained model")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=8,
                        help="Batch size per GPU/CPU for training")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=16,
                        help="Batch size per GPU/CPU for evaluation")
    parser.add_argument("--learning_rate", type=float, default=6e-5,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Warmup ratio")

    # Model specific arguments
    parser.add_argument("--query_max_length", type=int, default=64,
                        help="Maximum query length")
    parser.add_argument("--passage_max_length", type=int, default=448,
                        help="Maximum passage length")
    parser.add_argument("--loss_type", type=str, default="cross_entropy",
                        choices=["cross_entropy", "binary_cross_entropy", "listwise_ce"],
                        help="Loss function type")

    # Enhanced features
    parser.add_argument("--use_score_weighting", action="store_true",
                        help="Use score weighting in loss calculation")
    parser.add_argument("--label_smoothing", type=float, default=0.0,
                        help="Label smoothing for cross entropy loss")

    # System arguments
    parser.add_argument("--fp16", action="store_true",
                        help="Use 16-bit floating point precision")
    parser.add_argument("--dataloader_num_workers", type=int, default=4,
                        help="Number of workers for data loading")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--log_level", type=str, default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                        help="Logging level")

    # Backward compatibility.
    # BUG FIX: the original declared --legacy_support as action="store_true"
    # with default=True, which made the flag impossible to disable (passing
    # it only re-set an already-True value).  Keep --legacy_support working
    # unchanged and add --no_legacy_support to turn the feature off.
    parser.add_argument("--legacy_support", dest="legacy_support",
                        action="store_true", default=True,
                        help="Enable legacy nested format support")
    parser.add_argument("--no_legacy_support", dest="legacy_support",
                        action="store_false",
                        help="Disable legacy nested format support")
    parser.add_argument("--migrate_legacy", action="store_true",
                        help="Migrate legacy nested format to flat format")

    return parser.parse_args()
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(args):
    """Instantiate the BGE-Reranker model and its tokenizer.

    Args:
        args: parsed CLI namespace; reads ``model_name_or_path`` and
            ``cache_dir``.

    Returns:
        tuple: (BGERerankerModel, tokenizer).
    """
    logger.info(f"Loading BGE-Reranker model from {args.model_name_or_path}")

    # Tokenizer first (matches load order callers may rely on for logging),
    # then the cross-encoder model itself.
    tok = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        cache_dir=args.cache_dir,
        trust_remote_code=True,
    )
    reranker = BGERerankerModel(
        model_name_or_path=args.model_name_or_path,
        cache_dir=args.cache_dir,
    )

    return reranker, tok
|
||||
|
||||
|
||||
def migrate_legacy_data(args):
    """Convert legacy nested-format data to the flat schema when requested.

    Migration runs only when both ``--legacy_data`` and ``--migrate_legacy``
    are supplied.

    Returns:
        str | None: path of the migrated flat-format file, or None when no
        migration was requested.
    """
    # Guard clause: nothing to do unless both options were given.
    if not (args.legacy_data and args.migrate_legacy):
        return None

    logger.info("Migrating legacy nested data to flat format...")

    # Derive the output name from the input file's stem.
    stem, _ = os.path.splitext(args.legacy_data)
    target = f"{stem}_migrated_flat.jsonl"

    migrate_nested_to_flat(args.legacy_data, target)

    logger.info(f"✅ Migrated legacy data to {target}")
    return target
|
||||
|
||||
|
||||
def setup_datasets(args, tokenizer):
    """Build the training (and optional evaluation) reranker datasets.

    Migrates legacy nested data first when requested and appends the migrated
    examples to the primary training dataset.

    Returns:
        tuple: (train_dataset, eval_dataset) where eval_dataset may be None.
    """
    logger.info("Setting up reranker datasets...")

    # Migrate legacy data if needed (returns None when not requested).
    flat_legacy_path = migrate_legacy_data(args)

    # Keyword arguments shared by every dataset we construct here.
    shared = dict(
        tokenizer=tokenizer,
        query_max_length=args.query_max_length,
        passage_max_length=args.passage_max_length,
    )

    # Primary reranker training data.
    train_dataset = create_reranker_dataset(
        data_path=args.reranker_data,
        is_train=True,
        legacy_support=args.legacy_support,
        **shared,
    )
    logger.info(f"Loaded {len(train_dataset)} training examples from {args.reranker_data}")
    logger.info(f"Detected format: {train_dataset.format_type}")

    # Append migrated legacy examples, if any were produced above.
    if flat_legacy_path:
        logger.info(f"Loading migrated legacy data from {flat_legacy_path}")
        migrated = create_reranker_dataset(
            data_path=flat_legacy_path,
            is_train=True,
            legacy_support=False,  # Migrated data is in flat format
            **shared,
        )
        logger.info(f"Loaded {len(migrated)} migrated examples")

        # NOTE(review): merging assumes the dataset exposes its examples via a
        # mutable ``.data`` list attribute — confirm against data.dataset.
        train_dataset.data.extend(migrated.data)
        logger.info(f"Total training examples: {len(train_dataset)}")

    # Optional evaluation dataset.
    eval_dataset = None
    if args.eval_data:
        eval_dataset = create_reranker_dataset(
            data_path=args.eval_data,
            is_train=False,
            legacy_support=args.legacy_support,
            **shared,
        )
        logger.info(f"Loaded {len(eval_dataset)} evaluation examples")

    return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def setup_loss_function(args):
    """Build the training loss selected by ``args.loss_type``.

    Supported values: "cross_entropy" (optionally with label smoothing),
    "binary_cross_entropy", and "listwise_ce".

    Raises:
        ValueError: if ``args.loss_type`` is not a supported name.
    """
    kind = args.loss_type

    if kind == "cross_entropy":
        smoothing = args.label_smoothing
        loss_fn = (torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
                   if smoothing > 0
                   else torch.nn.CrossEntropyLoss())
    elif kind == "binary_cross_entropy":
        loss_fn = torch.nn.BCEWithLogitsLoss()
    elif kind == "listwise_ce":
        # NOTE(review): despite the name, this currently falls back to plain
        # cross-entropy; a dedicated listwise loss is not implemented yet.
        loss_fn = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError(f"Unknown loss type: {kind}")

    logger.info(f"Using loss function: {kind}")
    return loss_fn
|
||||
|
||||
|
||||
def train_epoch(model, dataloader, optimizer, scheduler, loss_fn, device, args):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: cross-encoder reranker (callable with input_ids/attention_mask).
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'labels'.
        optimizer / scheduler: stepped once per batch.
        loss_fn: loss matching ``args.loss_type``.
        device: torch device tensors are moved to.
        args: namespace; reads ``loss_type``.

    Returns:
        float: mean training loss over all batches.
    """
    model.train()
    running_loss = 0.0
    seen_batches = 0

    for step, batch in enumerate(dataloader):
        # Move the batch to the target device.
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass; some model wrappers return a plain tensor, others an
        # output object carrying a ``logits`` attribute.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = getattr(outputs, 'logits', outputs)

        # Binary loss expects a 1-D float target; multi-class uses raw logits.
        if args.loss_type == "binary_cross_entropy":
            loss = loss_fn(logits.squeeze(), labels.float())
        else:
            loss = loss_fn(logits, labels)

        # Backward pass and per-batch optimizer/scheduler update.
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        seen_batches += 1

        # Periodic progress logging (also fires on the first batch).
        if step % 100 == 0:
            logger.info(f"Batch {step}/{len(dataloader)}, Loss: {loss.item():.4f}")

    return running_loss / seen_batches
|
||||
|
||||
|
||||
def evaluate(model, dataloader, loss_fn, device, args):
    """Compute mean loss and accuracy over an evaluation dataloader.

    Args:
        model: cross-encoder reranker (callable with input_ids/attention_mask).
        dataloader: yields dicts with 'input_ids', 'attention_mask', 'labels'.
        loss_fn: loss matching ``args.loss_type``.
        device: torch device tensors are moved to.
        args: namespace; reads ``loss_type``.

    Returns:
        tuple[float, float]: (average loss per batch, accuracy); accuracy is 0
        when the dataloader yields no examples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass; unwrap ``logits`` when the model returns an
            # output object rather than a raw tensor.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = getattr(outputs, 'logits', outputs)

            if args.loss_type == "binary_cross_entropy":
                loss = loss_fn(logits.squeeze(), labels.float())
                # Binary prediction: threshold the sigmoid probability at 0.5.
                predictions = (torch.sigmoid(logits.squeeze()) > 0.5).long()
            else:
                loss = loss_fn(logits, labels)
                # Multi-class prediction: most likely class per example.
                predictions = torch.argmax(logits, dim=-1)

            loss_sum += loss.item()
            n_correct += (predictions == labels).sum().item()
            n_seen += labels.size(0)

    accuracy = n_correct / n_seen if n_seen > 0 else 0
    return loss_sum / len(dataloader), accuracy
|
||||
|
||||
|
||||
def main():
    """Entry point: parse args, build model and data, train, evaluate, save."""
    args = parse_args()

    # Setup logging
    setup_logging(
        output_dir=args.output_dir,
        log_level=args.log_level,
        log_to_file=True,
        log_to_console=True
    )

    logger.info("🚀 Starting BGE-Reranker Training")
    logger.info(f"Arguments: {vars(args)}")

    # Set random seed for reproducibility
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Setup device.
    # BUG FIX: the original selected the device with
    #   "cuda" if torch.cuda.is_available() and not args.fp16 else "cpu"
    # which forced training onto the CPU whenever --fp16 was passed — the
    # opposite of the intent, since fp16 training needs a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.fp16 and device.type != "cuda":
        logger.warning("--fp16 requested but CUDA is unavailable; running in fp32 on CPU")
    # NOTE(review): --fp16 is not otherwise wired into the training loop
    # (no autocast/GradScaler); confirm whether AMP support is intended.
    logger.info(f"Using device: {device}")

    # Setup model and tokenizer
    model, tokenizer = setup_model_and_tokenizer(args)
    model.to(device)

    # Setup datasets
    train_dataset, eval_dataset = setup_datasets(args, tokenizer)

    # Setup data loaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.per_device_train_batch_size,
        shuffle=True,
        num_workers=args.dataloader_num_workers,
        collate_fn=collate_reranker_batch,
        pin_memory=torch.cuda.is_available()
    )

    eval_dataloader = None
    if eval_dataset:
        eval_dataloader = DataLoader(
            eval_dataset,
            batch_size=args.per_device_eval_batch_size,
            shuffle=False,
            num_workers=args.dataloader_num_workers,
            collate_fn=collate_reranker_batch,
            pin_memory=torch.cuda.is_available()
        )

    # Setup optimizer and scheduler.
    # FIX: use torch.optim.AdamW — the AdamW implementation in
    # transformers.optimization is deprecated and removed in recent
    # transformers releases.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    total_steps = len(train_dataloader) * args.num_train_epochs
    warmup_steps = int(total_steps * args.warmup_ratio)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Setup loss function
    loss_fn = setup_loss_function(args)

    # Training loop
    logger.info("🎯 Starting training...")

    best_eval_loss = float('inf')

    for epoch in range(args.num_train_epochs):
        logger.info(f"Epoch {epoch + 1}/{args.num_train_epochs}")

        # Train for one epoch
        train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, loss_fn, device, args)
        logger.info(f"Training loss: {train_loss:.4f}")

        # Evaluate and checkpoint the best model by evaluation loss
        if eval_dataloader:
            eval_loss, eval_accuracy = evaluate(model, eval_dataloader, loss_fn, device, args)
            logger.info(f"Evaluation loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")

            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                best_model_path = os.path.join(args.output_dir, "best_model")
                os.makedirs(best_model_path, exist_ok=True)
                model.save_pretrained(best_model_path)
                tokenizer.save_pretrained(best_model_path)
                logger.info(f"Saved best model to {best_model_path}")

    # Save final model
    final_model_path = os.path.join(args.output_dir, "final_model")
    os.makedirs(final_model_path, exist_ok=True)
    model.save_pretrained(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    logger.info(f"✅ Training completed! Final model saved to {final_model_path}")

    # Show training summary
    logger.info("📊 Training Summary:")
    logger.info(f" Total examples: {len(train_dataset)}")
    logger.info(f" Format detected: {train_dataset.format_type}")
    logger.info(f" Epochs: {args.num_train_epochs}")
    logger.info(f" Learning rate: {args.learning_rate}")
    logger.info(f" Batch size: {args.per_device_train_batch_size}")
    logger.info(f" Loss function: {args.loss_type}")
|
||||
|
||||
|
||||
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user