bge_finetune/scripts/compare_retriever.py
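"""Compare a fine-tuned BGE-M3 retriever against a baseline checkpoint.

Both models are run over the same evaluation set; retrieval quality
(recall/precision/MAP/MRR/nDCG at the requested cut-offs), throughput,
per-sample latency, and peak Python-heap memory are reported as a
tab-separated comparison table.

Example invocation (all paths are illustrative placeholders):

    python scripts/compare_retriever.py \
        --finetuned_model_path ./output/bge-m3-finetuned \
        --baseline_model_path BAAI/bge-m3 \
        --data_path ./data/eval.jsonl \
        --output compare_retriever_results.txt
"""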

import argparse
import logging
import os
import time
import tracemalloc

import numpy as np
import psutil
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data.dataset import BGEM3Dataset
from evaluation.metrics import compute_retrieval_metrics
from models.bge_m3 import BGEM3Model
from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare_retriever")


def parse_args():
    """Parse command line arguments for retriever comparison."""
    parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline retriever models.")
    parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
    parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
    parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, cpu, npu); defaults to the configured hardware device')
    parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
    parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
    return parser.parse_args()


def measure_memory():
    """Measure the current process memory usage in MB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)  # MB
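

# Optional companion helper (a sketch, not wired into run_retriever): tracemalloc
# below only tracks Python-heap allocations, so device-side memory is invisible
# to it. A rough GPU figure can be read from torch's allocator statistics when
# running on CUDA.
def measure_gpu_memory(device):
    """Peak GPU memory allocated by this process in MB (best-effort, CUDA only)."""
    if device.startswith('cuda') and torch.cuda.is_available():
        return torch.cuda.max_memory_allocated() / (1024 * 1024)
    return 0.0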


def run_retriever(model_path, data_path, batch_size, max_samples, device):
    """
    Run inference with a retriever model and measure throughput, latency, and memory.

    Args:
        model_path (str): Path to the retriever model.
        data_path (str): Path to the evaluation data.
        batch_size (int): Batch size for inference.
        max_samples (int): Maximum number of samples to process.
        device (str): Device to use (cuda, cpu, npu).

    Returns:
        tuple: throughput (samples/sec), latency (seconds/sample), peak Python-heap memory (MB),
            query_embeds (np.ndarray), passage_embeds (np.ndarray), labels (np.ndarray).
    """
    model = BGEM3Model(model_name_or_path=model_path, device=device)
    tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
    dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    model.to(device)
    n_samples = 0
    times = []
    query_embeds = []
    passage_embeds = []
    labels = []
    tracemalloc.start()
    with torch.no_grad():
        for batch in dataloader:
            if n_samples >= max_samples:
                break
            start = time.time()
            out = model(
                batch['query_input_ids'].to(device),
                batch['query_attention_mask'].to(device),
                return_dense=True, return_sparse=False, return_colbert=False
            )
            if device.startswith('cuda'):
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
            n = batch['query_input_ids'].shape[0]
            n_samples += n
            # Collect dense embeddings and labels for the retrieval metrics; the
            # model forward is expected to expose query and passage embeddings
            # under these output keys.
            query_embeds.append(out['query_embeds'].cpu().numpy())
            passage_embeds.append(out['passage_embeds'].cpu().numpy())
            labels.append(batch['labels'].cpu().numpy())
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    throughput = n_samples / sum(times)
    latency = sum(times) / n_samples  # seconds per sample; robust to a smaller final batch
    mem_mb = peak / (1024 * 1024)
    # Stack per-batch arrays for metric computation
    query_embeds = np.concatenate(query_embeds, axis=0)
    passage_embeds = np.concatenate(passage_embeds, axis=0)
    labels = np.concatenate(labels, axis=0)
    return throughput, latency, mem_mb, query_embeds, passage_embeds, labels
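

# Illustrative only: compute_retrieval_metrics (evaluation.metrics) is the
# authoritative scorer. This minimal sketch shows how the dense embeddings
# collected above are conventionally scored, assuming labels[i] holds the
# index of query i's positive passage; adapt or ignore as needed.
def _recall_at_k_sketch(query_embeds, passage_embeds, labels, k=10):
    """Recall@k over dot-product scores between query and passage embeddings."""
    scores = query_embeds @ passage_embeds.T      # (num_queries, num_passages)
    topk = np.argsort(-scores, axis=1)[:, :k]     # top-k passage indices per query
    hits = (topk == labels.reshape(-1, 1)).any(axis=1)
    return float(hits.mean())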


def main():
    """Run the retriever comparison with config-driven defaults and basic error handling."""
    args = parse_args()
    # Fall back to the configured hardware device when --device is not given
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')  # type: ignore[attr-defined]
    try:
        # Fine-tuned model
        logger.info("Running fine-tuned retriever...")
        ft_throughput, ft_latency, ft_mem, ft_qe, ft_pe, ft_labels = run_retriever(
            args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
        ft_metrics = compute_retrieval_metrics(ft_qe, ft_pe, ft_labels, k_values=args.k_values)
        # Baseline model
        logger.info("Running baseline retriever...")
        bl_throughput, bl_latency, bl_mem, bl_qe, bl_pe, bl_labels = run_retriever(
            args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
        bl_metrics = compute_retrieval_metrics(bl_qe, bl_pe, bl_labels, k_values=args.k_values)
        # Write the tab-separated comparison table
        with open(args.output, 'w') as f:
            f.write("Metric\tBaseline\tFine-tuned\tDelta\n")
            for k in args.k_values:
                for metric in [f'recall@{k}', f'precision@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
                    bl = bl_metrics.get(metric, 0)
                    ft = ft_metrics.get(metric, 0)
                    delta = ft - bl
                    f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
            for metric in ['map', 'mrr']:
                bl = bl_metrics.get(metric, 0)
                ft = ft_metrics.get(metric, 0)
                delta = ft - bl
                f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
            f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput - bl_throughput:+.2f}\n")
            f.write(f"Latency (ms/sample)\t{bl_latency * 1000:.2f}\t{ft_latency * 1000:.2f}\t{(ft_latency - bl_latency) * 1000:+.2f}\n")
            f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem - bl_mem:+.2f}\n")
        logger.info(f"Comparison results saved to {args.output}")
        print(f"\nComparison complete. See {args.output} for details.")
    except Exception as e:
        logger.error(f"Retriever comparison failed: {e}")
        raise


if __name__ == '__main__':
    main()