bge_finetune/scripts/benchmark.py
2025-07-22 16:55:25 +08:00

145 lines
6.9 KiB
Python

import os
import time
import argparse
import logging
import torch
import psutil
import tracemalloc
import numpy as np
from utils.config_loader import get_config
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
# Configure the root logger at INFO level so the progress/result messages
# emitted by this script are printed.
logging.basicConfig(level=logging.INFO)
# Named logger used by the benchmark functions in this script.
logger = logging.getLogger("benchmark")
def parse_args():
    """Build the benchmarking command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model_type``, ``model_path``,
        ``data_path``, ``batch_size``, ``max_samples``, ``device`` and ``output``.
    """
    cli = argparse.ArgumentParser(description="Benchmark BGE fine-tuned models.")
    # (flag, keyword options) pairs, registered in order so --help output is stable.
    option_specs = (
        ('--model_type', dict(type=str, choices=['retriever', 'reranker'], required=True,
                              help='Model type to benchmark')),
        ('--model_path', dict(type=str, required=True,
                              help='Path to fine-tuned model')),
        ('--data_path', dict(type=str, required=True,
                             help='Path to evaluation data (JSONL)')),
        ('--batch_size', dict(type=int, default=16,
                              help='Batch size for inference')),
        ('--max_samples', dict(type=int, default=1000,
                               help='Max samples to benchmark (for speed)')),
        ('--device', dict(type=str, default='cuda',
                          help='Device to use (cuda, cpu, npu)')),
        ('--output', dict(type=str, default='benchmark_results.txt',
                          help='File to save benchmark results')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def measure_memory():
    """Return the resident set size (RSS) of the current process in MiB."""
    # psutil.Process() with no PID refers to the calling process.
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / 2 ** 20
def benchmark_retriever(model_path, data_path, batch_size, max_samples, device):
    """Benchmark dense-embedding inference of a fine-tuned BGE-M3 retriever.

    Feeds the query side of the evaluation set through the model in batches
    and times each forward pass. The host-to-device copy of the batch is
    included in the timed region.

    Args:
        model_path: Path to the fine-tuned model, passed to ``BGEM3Model``.
        data_path: Path to the evaluation data, passed to ``BGEM3Dataset``.
        batch_size: Number of samples per ``DataLoader`` batch.
        max_samples: Upper bound on the number of samples timed. The final
            batch is truncated so no more than this many samples are processed.
        device: Device string such as ``'cuda'``, ``'cuda:0'``, ``'npu'`` or ``'cpu'``.

    Returns:
        Tuple ``(throughput, latency, mem_mb)``:
        - ``throughput``: samples per second.
        - ``latency``: average seconds per sample.
        - ``mem_mb``: peak memory in MiB. On CUDA this is the device
          allocator's peak (``torch.cuda.max_memory_allocated``). Elsewhere
          it is the Python-heap peak reported by ``tracemalloc``.

    Raises:
        ValueError: If no samples were benchmarked (empty dataset or
            ``max_samples <= 0``).
    """
    device_str = str(device)
    on_cuda = device_str.startswith('cuda')
    on_npu = device_str.startswith('npu')

    def _sync():
        # Kernels are launched asynchronously; wait for them so the wall-clock
        # timing covers the actual computation, not just the launch.
        if on_cuda:
            torch.cuda.synchronize()
        elif on_npu and hasattr(torch, 'npu'):
            torch.npu.synchronize()

    logger.info(f"Loading retriever model from {model_path}")
    model = BGEM3Model(model_name_or_path=model_path, device=device)
    tokenizer = getattr(model, 'tokenizer', None) or AutoTokenizer.from_pretrained(model_path)
    dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
    logger.info(f"Loaded {len(dataset)} samples for benchmarking")
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)
    else:
        # tracemalloc only sees Python-heap allocations and slows Python code
        # down, so it is used only when device allocator stats are unavailable.
        tracemalloc.start()
    try:
        with torch.no_grad():
            for batch in dataloader:
                remaining = max_samples - n_samples
                if remaining <= 0:
                    break
                # Truncate the last batch so at most max_samples are timed.
                input_ids = batch['query_input_ids'][:remaining]
                attention_mask = batch['query_attention_mask'][:remaining]
                _sync()
                start = time.perf_counter()
                # Only dense embedding for speed
                _ = model(
                    input_ids.to(device),
                    attention_mask.to(device),
                    return_dense=True, return_sparse=False, return_colbert=False
                )
                _sync()
                times.append(time.perf_counter() - start)
                n_samples += input_ids.shape[0]
        if on_cuda:
            peak = torch.cuda.max_memory_allocated(device)
        else:
            _, peak = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing, even if the forward pass raised.
        if not on_cuda:
            tracemalloc.stop()

    if n_samples == 0:
        raise ValueError(
            "No samples were benchmarked; check that the dataset is non-empty "
            "and max_samples > 0")
    total_time = sum(times)
    throughput = n_samples / total_time
    # Per-sample latency from the actual sample count, so a partial final
    # batch does not skew the average.
    latency = total_time / n_samples
    mem_mb = peak / (1024 * 1024)
    logger.info(f"Retriever throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
    return throughput, latency, mem_mb
def benchmark_reranker(model_path, data_path, batch_size, max_samples, device):
    """Benchmark inference of a fine-tuned BGE reranker (cross-encoder).

    Feeds the evaluation pairs through the underlying ``model.model`` in
    batches and times each forward pass. The host-to-device copy of the batch
    is included in the timed region.

    Args:
        model_path: Path to the fine-tuned model, passed to ``BGERerankerModel``.
        data_path: Path to the evaluation data, passed to ``BGERerankerDataset``.
        batch_size: Number of samples per ``DataLoader`` batch.
        max_samples: Upper bound on the number of samples timed. The final
            batch is truncated so no more than this many samples are processed.
        device: Device string such as ``'cuda'``, ``'cuda:0'``, ``'npu'`` or ``'cpu'``.

    Returns:
        Tuple ``(throughput, latency, mem_mb)``:
        - ``throughput``: samples per second.
        - ``latency``: average seconds per sample.
        - ``mem_mb``: peak memory in MiB. On CUDA this is the device
          allocator's peak (``torch.cuda.max_memory_allocated``). Elsewhere
          it is the Python-heap peak reported by ``tracemalloc``.

    Raises:
        ValueError: If no samples were benchmarked (empty dataset or
            ``max_samples <= 0``).
    """
    device_str = str(device)
    on_cuda = device_str.startswith('cuda')
    on_npu = device_str.startswith('npu')

    def _sync():
        # Kernels are launched asynchronously; wait for them so the wall-clock
        # timing covers the actual computation, not just the launch.
        if on_cuda:
            torch.cuda.synchronize()
        elif on_npu and hasattr(torch, 'npu'):
            torch.npu.synchronize()

    logger.info(f"Loading reranker model from {model_path}")
    model = BGERerankerModel(model_name_or_path=model_path, device=device)
    tokenizer = getattr(model, 'tokenizer', None) or AutoTokenizer.from_pretrained(model_path)
    dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
    logger.info(f"Loaded {len(dataset)} samples for benchmarking")
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    model.to(device)

    n_samples = 0
    times = []
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)
    else:
        # tracemalloc only sees Python-heap allocations and slows Python code
        # down, so it is used only when device allocator stats are unavailable.
        tracemalloc.start()
    try:
        with torch.no_grad():
            for batch in dataloader:
                remaining = max_samples - n_samples
                if remaining <= 0:
                    break
                # Truncate the last batch so at most max_samples are timed.
                input_ids = batch['input_ids'][:remaining]
                attention_mask = batch['attention_mask'][:remaining]
                _sync()
                start = time.perf_counter()
                _ = model.model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device)
                )
                _sync()
                times.append(time.perf_counter() - start)
                n_samples += input_ids.shape[0]
        if on_cuda:
            peak = torch.cuda.max_memory_allocated(device)
        else:
            _, peak = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing, even if the forward pass raised.
        if not on_cuda:
            tracemalloc.stop()

    if n_samples == 0:
        raise ValueError(
            "No samples were benchmarked; check that the dataset is non-empty "
            "and max_samples > 0")
    total_time = sum(times)
    throughput = n_samples / total_time
    # Per-sample latency from the actual sample count, so a partial final
    # batch does not skew the average.
    latency = total_time / n_samples
    mem_mb = peak / (1024 * 1024)
    logger.info(f"Reranker throughput: {throughput:.2f} samples/sec, latency: {latency*1000:.2f} ms/sample, peak mem: {mem_mb:.2f} MB")
    return throughput, latency, mem_mb
def main():
    """Parse CLI arguments, run the selected benchmark and save its results.

    ``--device`` falls back to ``config['hardware']['device']`` from
    ``get_config()``, then to ``'cuda'``. Results are written as plain
    ``Key: value`` lines to ``--output``. Any failure is logged and re-raised.
    """
    parser = argparse.ArgumentParser(description="Benchmark BGE models.")
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
    parser.add_argument('--model_type', type=str, choices=['retriever', 'reranker'], required=True, help='Model type to benchmark')
    parser.add_argument('--model_path', type=str, required=True, help='Path to fine-tuned model')
    parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
    parser.add_argument('--output', type=str, default='benchmark_results.txt', help='File to save benchmark results')
    args = parser.parse_args()
    # Reject values that would leave nothing to time (and divide by zero later).
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer")
    if args.max_samples <= 0:
        parser.error("--max_samples must be a positive integer")

    # Config only supplies the device when --device was not given.
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')  # type: ignore[attr-defined]
    benchmark_fn = benchmark_retriever if args.model_type == 'retriever' else benchmark_reranker
    try:
        throughput, latency, mem_mb = benchmark_fn(
            args.model_path, args.data_path, args.batch_size, args.max_samples, device)
        # Save results
        with open(args.output, 'w', encoding='utf-8') as f:
            f.write(f"Model: {args.model_path}\n")
            f.write(f"Data: {args.data_path}\n")
            f.write(f"Batch size: {args.batch_size}\n")
            f.write(f"Throughput: {throughput:.2f} samples/sec\n")
            f.write(f"Latency: {latency*1000:.2f} ms/sample\n")
            f.write(f"Peak memory: {mem_mb:.2f} MB\n")
            # Record what was benchmarked and where, so result files are comparable.
            f.write(f"Model type: {args.model_type}\n")
            f.write(f"Device: {device}\n")
        logger.info(f"Benchmark results saved to {args.output}")
        logger.info(f"Benchmark completed on device: {device}")
    except Exception as e:
        logger.error(f"Benchmark failed: {e}")
        raise
# Run the benchmark CLI only when this file is executed as a script.
if __name__ == '__main__':
    main()