import os
import time
import argparse
import logging

import torch
import psutil
import tracemalloc
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from utils.config_loader import get_config
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("benchmark")


def parse_args():
    """Parse command line arguments for benchmarking."""
    parser = argparse.ArgumentParser(description="Benchmark BGE fine-tuned models.")
    parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True,
                        help='Model type to benchmark')
    parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
    parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda, cpu, npu); falls back to the config value when unset')
    parser.add_argument('--output', type=str, default='benchmark_results.txt',
                        help='File to save benchmark results')
    return parser.parse_args()


def measure_memory():
    """Return the current resident memory usage of the process in MB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)  # MB


def benchmark_retriever(model_path, data_path, batch_size, max_samples, device):
    """Benchmark the BGE retriever model."""
    logger.info(f"Loading retriever model from {model_path}")
    model = BGEM3Model(model_name_or_path=model_path, device=device)
    tokenizer = (model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer
                 else AutoTokenizer.from_pretrained(model_path))
    dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
    logger.info(f"Loaded {len(dataset)} samples for benchmarking")
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    # tracemalloc tracks Python-heap allocations only, not GPU memory.
    tracemalloc.start()
    with torch.no_grad():
        for batch in dataloader:
            if n_samples >= max_samples:
                break
            start = time.time()
            # Only dense embedding for speed
            _ = model(
                batch['query_input_ids'].to(device),
                batch['query_attention_mask'].to(device),
                return_dense=True,
                return_sparse=False,
                return_colbert=False
            )
            if device.startswith('cuda'):
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
            n_samples += batch['query_input_ids'].shape[0]
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    throughput = n_samples / sum(times)
    latency = np.mean(times) / batch_size
    mem_mb = peak / (1024 * 1024)
    logger.info(f"Retriever throughput: {throughput:.2f} samples/sec, "
                f"latency: {latency * 1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
    return throughput, latency, mem_mb


def benchmark_reranker(model_path, data_path, batch_size, max_samples, device):
    """Benchmark the BGE reranker model."""
    logger.info(f"Loading reranker model from {model_path}")
    model = BGERerankerModel(model_name_or_path=model_path, device=device)
    tokenizer = (model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer
                 else AutoTokenizer.from_pretrained(model_path))
    dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
    logger.info(f"Loaded {len(dataset)} samples for benchmarking")
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    tracemalloc.start()
    with torch.no_grad():
        for batch in dataloader:
            if n_samples >= max_samples:
                break
            start = time.time()
            _ = model.model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device)
            )
            if device.startswith('cuda'):
                torch.cuda.synchronize()
            end = time.time()
            times.append(end - start)
            n_samples += batch['input_ids'].shape[0]
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    throughput = n_samples / sum(times)
    latency = np.mean(times) / batch_size
    mem_mb = peak / (1024 * 1024)
    logger.info(f"Reranker throughput: {throughput:.2f} samples/sec, "
                f"latency: {latency * 1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
    return throughput, latency, mem_mb


def main():
    """Run the benchmark with robust error handling and config-driven defaults."""
    args = parse_args()

    # Fall back to the configured device when --device is not given.
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')  # type: ignore[attr-defined]

    try:
        if args.model_type == 'retriever':
            throughput, latency, mem_mb = benchmark_retriever(
                args.model_path, args.data_path, args.batch_size, args.max_samples, device)
        else:
            throughput, latency, mem_mb = benchmark_reranker(
                args.model_path, args.data_path, args.batch_size, args.max_samples, device)

        # Save results
        with open(args.output, 'w') as f:
            f.write(f"Model: {args.model_path}\n")
            f.write(f"Data: {args.data_path}\n")
            f.write(f"Batch size: {args.batch_size}\n")
            f.write(f"Throughput: {throughput:.2f} samples/sec\n")
            f.write(f"Latency: {latency * 1000:.2f} ms/sample\n")
            f.write(f"Peak memory: {mem_mb:.2f} MB\n")
        logger.info(f"Benchmark results saved to {args.output}")
        logger.info(f"Benchmark completed on device: {device}")
    except Exception as e:
        logger.error(f"Benchmark failed: {e}")
        raise


if __name__ == '__main__':
    main()
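
# Example invocations (illustrative only; the script filename and the model/data
# paths below are assumptions, not taken from this repository):
#
#   python benchmark.py --model_type retriever \
#       --model_path ./output/bge-m3-finetuned \
#       --data_path ./data/eval.jsonl \
#       --batch_size 16 --max_samples 1000 --device cuda
#
#   python benchmark.py --model_type reranker \
#       --model_path ./output/bge-reranker-finetuned \
#       --data_path ./data/eval_pairs.jsonl \
#       --output reranker_benchmark.txt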