"""Compare a fine-tuned BGE reranker against a baseline model on the same evaluation set."""

import argparse
import logging
import os
import time
import tracemalloc

import numpy as np
import psutil
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data.dataset import BGERerankerDataset
from evaluation.metrics import compute_reranker_metrics
from models.bge_reranker import BGERerankerModel
from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("compare_reranker")


def parse_args():
    """Parse command-line arguments for the reranker comparison."""
    parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline reranker models.")
    parser.add_argument('--finetuned_model_path', type=str, required=True,
                        help='Path to fine-tuned reranker model')
    parser.add_argument('--baseline_model_path', type=str, required=True,
                        help='Path to baseline reranker model')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000,
                        help='Max samples to benchmark (for speed)')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda, cpu, npu); falls back to the config default')
    parser.add_argument('--output', type=str, default='compare_reranker_results.txt',
                        help='File to save comparison results')
    parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10],
                        help='K values for metrics@k')
    return parser.parse_args()


def measure_memory():
    """Return the current resident set size of this process in MB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)  # MB


def run_reranker(model_path, data_path, batch_size, max_samples, device):
    """
    Run inference with a reranker model and measure throughput, latency, and memory.

    Args:
        model_path (str): Path to the reranker model.
        data_path (str): Path to the evaluation data.
        batch_size (int): Batch size for inference.
        max_samples (int): Maximum number of samples to process.
        device (str): Device to use (cuda, cpu, npu).

    Returns:
        tuple: (throughput, latency, mem_mb, all_scores, all_labels)
    """
    model = BGERerankerModel(model_name_or_path=model_path, device=device)
    tokenizer = (model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer
                 else AutoTokenizer.from_pretrained(model_path))
    dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    all_scores = []
    all_labels = []

    # Note: tracemalloc tracks Python heap allocations on the host only, not device memory.
    tracemalloc.start()
    with torch.no_grad():
        for batch in dataloader:
            if n_samples >= max_samples:
                break
            start = time.time()
            outputs = model.model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device)
            )
            if device.startswith('cuda'):
                # Wait for GPU kernels to finish so the timing reflects real latency.
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)

            n = batch['input_ids'].shape[0]
            n_samples += n

            # Collect scores and labels for metric computation.
            scores = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
            all_scores.append(scores.cpu().numpy())
            all_labels.append(batch['labels'].cpu().numpy())
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    throughput = n_samples / sum(times)
    # Approximate per-sample latency; the final batch may be smaller than batch_size.
    latency = np.mean(times) / batch_size
    mem_mb = peak / (1024 * 1024)

    # Prepare score/label arrays for metrics.
    all_scores = np.concatenate(all_scores, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    return throughput, latency, mem_mb, all_scores, all_labels


def main():
    """Run the reranker comparison with config-driven defaults and error handling."""
    args = parse_args()

    # Load config and resolve the device default.
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')  # type: ignore[attr-defined]

    try:
        # Fine-tuned model
        logger.info("Running fine-tuned reranker...")
        ft_throughput, ft_latency, ft_mem, ft_scores, ft_labels = run_reranker(
            args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
        ft_metrics = compute_reranker_metrics(ft_scores, ft_labels, k_values=args.k_values)

        # Baseline model
        logger.info("Running baseline reranker...")
        bl_throughput, bl_latency, bl_mem, bl_scores, bl_labels = run_reranker(
            args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
        bl_metrics = compute_reranker_metrics(bl_scores, bl_labels, k_values=args.k_values)

        # Write the comparison as a tab-separated table.
        with open(args.output, 'w') as f:
            f.write("Metric\tBaseline\tFine-tuned\tDelta\n")
            for k in args.k_values:
                for metric in [f'accuracy@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
                    bl = bl_metrics.get(metric, 0)
                    ft = ft_metrics.get(metric, 0)
                    delta = ft - bl
                    f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
            for metric in ['map', 'mrr']:
                bl = bl_metrics.get(metric, 0)
                ft = ft_metrics.get(metric, 0)
                delta = ft - bl
                f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
            f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput - bl_throughput:+.2f}\n")
            f.write(f"Latency (ms/sample)\t{bl_latency * 1000:.2f}\t{ft_latency * 1000:.2f}\t{(ft_latency - bl_latency) * 1000:+.2f}\n")
            f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem - bl_mem:+.2f}\n")

        logger.info(f"Comparison results saved to {args.output}")
        print(f"\nComparison complete. See {args.output} for details.")
    except Exception as e:
        logger.error(f"Reranker comparison failed: {e}")
        raise


if __name__ == '__main__':
    main()
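
# Example invocation (illustrative only): the script filename is assumed from the
# logger name, and the model/data paths below are placeholders rather than files
# shipped with this repo. The flags themselves match the argparse definitions above.
#
#   python compare_reranker.py \
#       --finetuned_model_path ./output/reranker_finetuned \
#       --baseline_model_path ./models/reranker_baseline \
#       --data_path ./data/eval.jsonl \
#       --batch_size 16 --max_samples 1000 --device cuda --k_values 1 5 10
#
# The output file is a tab-separated table with one row per metric@k, plus overall
# map/mrr, throughput (samples/sec), latency (ms/sample), and peak host memory (MB).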