init commit
This commit is contained in:
138
scripts/compare_reranker.py
Normal file
138
scripts/compare_reranker.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import psutil
|
||||
import tracemalloc
|
||||
import numpy as np
|
||||
from utils.config_loader import get_config
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGERerankerDataset
|
||||
from transformers import AutoTokenizer
|
||||
from evaluation.metrics import compute_reranker_metrics
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Script-wide logger; the explicit name keeps records labelled
# "compare_reranker" even when the module runs as ``__main__``.
logger = logging.getLogger("compare_reranker")

# Configure the root logger so INFO-level progress messages are emitted.
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line arguments for reranker comparison.

    Args:
        argv (list[str] | None): Argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so the
            existing zero-argument call keeps working unchanged. Passing an
            explicit list makes the parser usable from tests and other code.

    Returns:
        argparse.Namespace: Parsed options with attributes
        ``finetuned_model_path``, ``baseline_model_path``, ``data_path``,
        ``batch_size``, ``max_samples``, ``device``, ``output`` and
        ``k_values``.
    """
    parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline reranker models.")
    parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
    parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
    parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
    parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
    # With nargs='+', a user-supplied list replaces the default rather than
    # extending it, so the shared default list is never mutated.
    parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
    return parser.parse_args(argv)
|
||||
|
||||
def measure_memory():
    """Return the resident set size (RSS) of the current process in MB."""
    bytes_per_mb = 1024 * 1024
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_mb
|
||||
|
||||
def run_reranker(model_path, data_path, batch_size, max_samples, device):
    """
    Run inference on a reranker model and measure throughput, latency, and memory.

    Batches are consumed whole, so the number of processed samples can exceed
    ``max_samples`` by up to ``batch_size - 1``.

    Args:
        model_path (str): Path to the reranker model.
        data_path (str): Path to the evaluation data.
        batch_size (int): Batch size for inference.
        max_samples (int): Stop starting new batches once this many samples
            have been processed.
        device (str): Device to use (cuda, cpu, npu).

    Returns:
        tuple: (throughput, latency, mem_mb, all_scores, all_labels) where
        ``throughput`` is samples per second, ``latency`` is seconds per
        sample, ``mem_mb`` is the peak traced memory in MB, and the last two
        are the per-batch scores/labels concatenated along axis 0.

    Raises:
        ValueError: If no samples were processed (empty dataset or
            ``max_samples <= 0``), since no rate can be computed.
    """
    model = BGERerankerModel(model_name_or_path=model_path, device=device)
    # Prefer the tokenizer carried by the model wrapper; otherwise load one
    # from the same checkpoint path.
    tokenizer = getattr(model, 'tokenizer', None) or AutoTokenizer.from_pretrained(model_path)
    dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    model.to(device)

    use_cuda = device.startswith('cuda')
    n_samples = 0
    times = []
    all_scores = []
    all_labels = []

    # NOTE(review): tracemalloc only traces allocations made through Python's
    # allocator; device (GPU/NPU) memory is presumably not reflected in
    # mem_mb -- confirm this is the intended metric.
    tracemalloc.start()
    try:
        with torch.no_grad():
            for batch in dataloader:
                if n_samples >= max_samples:
                    break
                # perf_counter is monotonic and high-resolution, unlike
                # time.time(), which can jump with wall-clock adjustments.
                start = time.perf_counter()
                outputs = model.model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                )
                if use_cuda:
                    # CUDA kernels run asynchronously; wait for them so the
                    # measured interval covers the actual computation.
                    torch.cuda.synchronize()
                times.append(time.perf_counter() - start)
                n_samples += batch['input_ids'].shape[0]
                # For metrics: collect scores and labels
                scores = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
                all_scores.append(scores.cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
        _, peak = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing, even on failure, so a later run in the same
        # process is not slowed down by a leftover tracer.
        tracemalloc.stop()

    if n_samples == 0:
        raise ValueError(
            f"No samples were processed from {data_path!r} "
            f"(max_samples={max_samples}); cannot compute throughput or latency."
        )

    total_time = sum(times)
    throughput = n_samples / total_time
    # Divide by the real sample count: the final batch may be smaller than
    # batch_size, which made mean(batch_time) / batch_size under-report.
    latency = total_time / n_samples
    mem_mb = peak / (1024 * 1024)

    # Prepare for metrics
    all_scores = np.concatenate(all_scores, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    return throughput, latency, mem_mb, all_scores, all_labels
|
||||
|
||||
def main():
    """Run reranker comparison with robust error handling and config-driven defaults.

    Parses the command line, resolves the device (``--device`` if given,
    otherwise ``hardware.device`` from the loaded config, otherwise
    ``'cuda'``), benchmarks the fine-tuned and baseline models on the same
    data, and writes a tab-separated Baseline / Fine-tuned / Delta table to
    ``--output``.

    Raises:
        Exception: Any failure during benchmarking or report writing is
            logged and then re-raised unchanged.
    """
    # The parser is built here with ``--device`` defaulting to None so that
    # an omitted flag can fall back to the config value below.
    parser = argparse.ArgumentParser(description="Compare reranker models.")
    parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
    parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
    parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
    parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
    parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
    parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
    parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
    args = parser.parse_args()

    # Load config and set defaults
    config = get_config()
    device = args.device or config.get('hardware', {}).get('device', 'cuda')  # type: ignore[attr-defined]

    try:
        # Fine-tuned model
        logger.info("Running fine-tuned reranker...")
        ft_throughput, ft_latency, ft_mem, ft_scores, ft_labels = run_reranker(
            args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
        ft_metrics = compute_reranker_metrics(ft_scores, ft_labels, k_values=args.k_values)

        # Baseline model
        logger.info("Running baseline reranker...")
        bl_throughput, bl_latency, bl_mem, bl_scores, bl_labels = run_reranker(
            args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
        bl_metrics = compute_reranker_metrics(bl_scores, bl_labels, k_values=args.k_values)

        # Per-k metrics first, then the unsuffixed aggregate metrics once.
        metric_names = [
            f'{name}@{k}'
            for k in args.k_values
            for name in ('accuracy', 'map', 'mrr', 'ndcg')
        ]
        metric_names += ['map', 'mrr']

        rows = ["Metric\tBaseline\tFine-tuned\tDelta\n"]
        for metric in metric_names:
            # Metrics missing from either result are reported as 0.
            bl = bl_metrics.get(metric, 0)
            ft = ft_metrics.get(metric, 0)
            delta = ft - bl
            rows.append(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
        rows.append(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
        # Latencies are in seconds per sample; scale to milliseconds.
        rows.append(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
        rows.append(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")

        # Output comparison
        with open(args.output, 'w', encoding='utf-8') as f:
            f.writelines(rows)

        logger.info("Comparison results saved to %s", args.output)
        print(f"\nComparison complete. See {args.output} for details.")
    except Exception as e:
        # Log through the module logger (not the root logger) so the record
        # carries the same name as every other message from this script. The
        # re-raise below already prints the traceback, so log the message only.
        logger.error("Reranker comparison failed: %s", e)
        raise
|
||||
|
||||
# Script entry point: only run the comparison when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user