#!/usr/bin/env python3
"""
Unified BGE Model Comparison Script

This script compares fine-tuned BGE models (retriever/reranker) against baseline models,
providing both accuracy metrics and performance comparisons.

Usage:
    # Compare retriever models
    python scripts/compare_models.py \
        --model_type retriever \
        --finetuned_model ./output/bge_m3_三国演义/final_model \
        --baseline_model BAAI/bge-m3 \
        --data_path data/datasets/三国演义/splits/m3_test.jsonl

    # Compare reranker models
    python scripts/compare_models.py \
        --model_type reranker \
        --finetuned_model ./output/bge_reranker_三国演义/final_model \
        --baseline_model BAAI/bge-reranker-base \
        --data_path data/datasets/三国演义/splits/reranker_test.jsonl

    # Compare both with comprehensive output
    python scripts/compare_models.py \
        --model_type both \
        --finetuned_retriever ./output/bge_m3_三国演义/final_model \
        --finetuned_reranker ./output/bge_reranker_三国演义/final_model \
        --baseline_retriever BAAI/bge-m3 \
        --baseline_reranker BAAI/bge-reranker-base \
        --retriever_data data/datasets/三国演义/splits/m3_test.jsonl \
        --reranker_data data/datasets/三国演义/splits/reranker_test.jsonl
"""

import os
import sys
import time
import argparse
import logging
import json
from pathlib import Path
from typing import Dict, Tuple, Optional, Any

import torch
import psutil
import numpy as np
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics
from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ModelComparator:
    """Unified model comparison class for both retriever and reranker models"""

    def __init__(self, device: str = 'auto'):
        self.device = self._get_device(device)

    def _get_device(self, device: str) -> torch.device:
        """Get optimal device"""
        if device == 'auto':
            try:
                config = get_config()
                device_type = config.get('hardware', {}).get('device', 'cuda')
                if device_type == 'npu':
                    try:
                        import torch_npu  # noqa: F401
                        # NPU availability is reported via torch.npu (registered
                        # by torch_npu), not via torch.cuda
                        if torch.npu.is_available():
                            return torch.device("npu:0")
                    except ImportError:
                        pass

                if torch.cuda.is_available():
                    return torch.device("cuda")

            except Exception:
                pass

            return torch.device("cpu")
        else:
            return torch.device(device)

    def measure_memory(self) -> float:
        """Measure current memory usage in MB"""
        process = psutil.Process(os.getpid())
        return process.memory_info().rss / (1024 * 1024)

    def compare_retrievers(
        self,
        finetuned_path: str,
        baseline_path: str,
        data_path: str,
        batch_size: int = 16,
        max_samples: Optional[int] = None,
        k_values: list = [1, 5, 10, 20, 100]
    ) -> Dict[str, Any]:
        """Compare retriever models"""
        logger.info("🔍 Comparing Retriever Models")
        logger.info(f"Fine-tuned: {finetuned_path}")
        logger.info(f"Baseline: {baseline_path}")

        # Test both models
        ft_results = self._benchmark_retriever(finetuned_path, data_path, batch_size, max_samples, k_values)
        baseline_results = self._benchmark_retriever(baseline_path, data_path, batch_size, max_samples, k_values)

        # Compute improvements
        comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])

        return {
            'model_type': 'retriever',
            'finetuned': ft_results,
            'baseline': baseline_results,
            'comparison': comparison,
            'summary': self._generate_retriever_summary(comparison)
        }

    def compare_rerankers(
        self,
        finetuned_path: str,
        baseline_path: str,
        data_path: str,
        batch_size: int = 32,
        max_samples: Optional[int] = None,
        k_values: list = [1, 5, 10]
    ) -> Dict[str, Any]:
        """Compare reranker models"""
        logger.info("🏆 Comparing Reranker Models")
        logger.info(f"Fine-tuned: {finetuned_path}")
        logger.info(f"Baseline: {baseline_path}")

        # Test both models
        ft_results = self._benchmark_reranker(finetuned_path, data_path, batch_size, max_samples, k_values)
        baseline_results = self._benchmark_reranker(baseline_path, data_path, batch_size, max_samples, k_values)

        # Compute improvements
        comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])

        return {
            'model_type': 'reranker',
            'finetuned': ft_results,
            'baseline': baseline_results,
            'comparison': comparison,
            'summary': self._generate_reranker_summary(comparison)
        }

    def _benchmark_retriever(
        self,
        model_path: str,
        data_path: str,
        batch_size: int,
        max_samples: Optional[int],
        k_values: list
    ) -> Dict[str, Any]:
        """Benchmark a retriever model"""
        logger.info(f"Benchmarking retriever: {model_path}")

        # Load model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGEM3Model.from_pretrained(model_path)
        model.to(self.device)
        model.eval()

        # Load data
        dataset = BGEM3Dataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=False
        )
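        # Optional truncation for quick runs; this assumes the dataset keeps its
        # examples in a `.data` list, which is how it is used in this script.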
        if max_samples and len(dataset) > max_samples:
            dataset.data = dataset.data[:max_samples]

        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        logger.info(f"Testing on {len(dataset)} samples")

        # Benchmark
        start_memory = self.measure_memory()
        start_time = time.time()

        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                # Move to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # Get embeddings
                query_outputs = model.forward(
                    input_ids=batch['query_input_ids'],
                    attention_mask=batch['query_attention_mask'],
                    return_dense=True,
                    return_dict=True
                )

                passage_outputs = model.forward(
                    input_ids=batch['passage_input_ids'],
                    attention_mask=batch['passage_attention_mask'],
                    return_dense=True,
                    return_dict=True
                )

                all_query_embeddings.append(query_outputs['dense'].cpu().numpy())
                all_passage_embeddings.append(passage_outputs['dense'].cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())

        end_time = time.time()
        peak_memory = self.measure_memory()

        # Compute metrics
        query_embeds = np.concatenate(all_query_embeddings, axis=0)
        passage_embeds = np.concatenate(all_passage_embeddings, axis=0)
        labels = np.concatenate(all_labels, axis=0)
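        # query_embeds, passage_embeds and labels are row-aligned: index i in
        # each array comes from the same evaluation sample.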

        metrics = compute_retrieval_metrics(
            query_embeds, passage_embeds, labels, k_values=k_values
        )

        # Performance stats
        total_time = end_time - start_time
        throughput = len(dataset) / total_time
        latency = (total_time * 1000) / len(dataset)  # ms per sample
        memory_usage = peak_memory - start_memory

        return {
            'model_path': model_path,
            'metrics': metrics,
            'performance': {
                'throughput_samples_per_sec': throughput,
                'latency_ms_per_sample': latency,
                'memory_usage_mb': memory_usage,
                'total_time_sec': total_time,
                'samples_processed': len(dataset)
            }
        }

    def _benchmark_reranker(
        self,
        model_path: str,
        data_path: str,
        batch_size: int,
        max_samples: Optional[int],
        k_values: list
    ) -> Dict[str, Any]:
        """Benchmark a reranker model"""
        logger.info(f"Benchmarking reranker: {model_path}")

        # Load model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGERerankerModel.from_pretrained(model_path)
        model.to(self.device)
        model.eval()

        # Load data
        dataset = BGERerankerDataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=False
        )

        if max_samples and len(dataset) > max_samples:
            dataset.data = dataset.data[:max_samples]

        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        logger.info(f"Testing on {len(dataset)} samples")

        # Benchmark
        start_memory = self.measure_memory()
        start_time = time.time()

        all_scores = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                # Move to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # Get scores
                outputs = model.forward(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    return_dict=True
                )
                logits = outputs.logits if hasattr(outputs, 'logits') else outputs.scores
                # Squeeze only the trailing score dimension so that a final batch
                # of size 1 still yields a 1-D array for np.concatenate below
                scores = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()

                all_scores.append(scores)
                all_labels.append(batch['labels'].cpu().numpy())

        end_time = time.time()
        peak_memory = self.measure_memory()

        # Compute metrics
        scores = np.concatenate(all_scores, axis=0)
        labels = np.concatenate(all_labels, axis=0)
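        # scores and labels are flat 1-D arrays, index-aligned per query-passage pair.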

        metrics = compute_reranker_metrics(scores, labels, k_values=k_values)

        # Performance stats
        total_time = end_time - start_time
        throughput = len(dataset) / total_time
        latency = (total_time * 1000) / len(dataset)  # ms per sample
        memory_usage = peak_memory - start_memory

        return {
            'model_path': model_path,
            'metrics': metrics,
            'performance': {
                'throughput_samples_per_sec': throughput,
                'latency_ms_per_sample': latency,
                'memory_usage_mb': memory_usage,
                'total_time_sec': total_time,
                'samples_processed': len(dataset)
            }
        }

    def _compute_metric_improvements(self, baseline_metrics: Dict, finetuned_metrics: Dict) -> Dict[str, float]:
        """Compute percentage improvements for each metric"""
        improvements = {}
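        # Improvement is the relative delta in percent:
        #   (finetuned - baseline) / baseline * 100
        # A zero baseline maps to 0% (both zero) or +inf (any gain from zero).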

        for metric_name in baseline_metrics:
            baseline_val = baseline_metrics[metric_name]
            finetuned_val = finetuned_metrics.get(metric_name, baseline_val)

            if baseline_val == 0:
                improvement = 0.0 if finetuned_val == 0 else float('inf')
            else:
                improvement = ((finetuned_val - baseline_val) / baseline_val) * 100

            improvements[metric_name] = improvement

        return improvements

    def _generate_retriever_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
        """Generate summary for retriever comparison"""
        key_metrics = ['recall@5', 'recall@10', 'map', 'mrr']
        improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]

        avg_improvement = np.mean(improvements) if improvements else 0
        positive_improvements = sum(1 for imp in improvements if imp > 0)

        if avg_improvement > 5:
            status = "🌟 EXCELLENT"
        elif avg_improvement > 2:
            status = "✅ GOOD"
        elif avg_improvement > 0:
            status = "👌 FAIR"
        else:
            status = "❌ POOR"

        return {
            'status': status,
            'avg_improvement_pct': avg_improvement,
            'positive_improvements': positive_improvements,
            'total_metrics': len(improvements),
            'key_improvements': {metric: comparison.get(metric, 0) for metric in key_metrics if metric in comparison}
        }

    def _generate_reranker_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
        """Generate summary for reranker comparison"""
        key_metrics = ['accuracy', 'precision', 'recall', 'f1', 'auc']
        improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]

        avg_improvement = np.mean(improvements) if improvements else 0
        positive_improvements = sum(1 for imp in improvements if imp > 0)

        if avg_improvement > 3:
            status = "🌟 EXCELLENT"
        elif avg_improvement > 1:
            status = "✅ GOOD"
        elif avg_improvement > 0:
            status = "👌 FAIR"
        else:
            status = "❌ POOR"

        return {
            'status': status,
            'avg_improvement_pct': avg_improvement,
            'positive_improvements': positive_improvements,
            'total_metrics': len(improvements),
            'key_improvements': {metric: comparison.get(metric, 0) for metric in key_metrics if metric in comparison}
        }

    def print_comparison_results(self, results: Dict[str, Any], output_file: Optional[str] = None):
        """Print formatted comparison results"""
        model_type = results['model_type'].upper()
        summary = results['summary']

        output = []
        output.append(f"\n{'='*80}")
        output.append(f"📊 {model_type} MODEL COMPARISON RESULTS")
        output.append(f"{'='*80}")

        output.append(f"\n🎯 OVERALL ASSESSMENT: {summary['status']}")
        output.append(f"📈 Average Improvement: {summary['avg_improvement_pct']:.2f}%")
        output.append(f"✅ Metrics Improved: {summary['positive_improvements']}/{summary['total_metrics']}")

        # Key metrics comparison
        output.append(f"\n📋 KEY METRICS COMPARISON:")
        output.append(f"{'Metric':<20} {'Baseline':<12} {'Fine-tuned':<12} {'Improvement':<12}")
        output.append("-" * 60)

        for metric, improvement in summary['key_improvements'].items():
            baseline_val = results['baseline']['metrics'].get(metric, 0)
            finetuned_val = results['finetuned']['metrics'].get(metric, 0)

            status_icon = "📈" if improvement > 0 else "📉" if improvement < 0 else "➡️"
            output.append(f"{metric:<20} {baseline_val:<12.4f} {finetuned_val:<12.4f} {status_icon} {improvement:>+7.2f}%")

        # Performance comparison
        output.append(f"\n⚡ PERFORMANCE COMPARISON:")
        baseline_perf = results['baseline']['performance']
        finetuned_perf = results['finetuned']['performance']

        output.append(f"{'Metric':<25} {'Baseline':<15} {'Fine-tuned':<15} {'Change':<15}")
        output.append("-" * 75)

        perf_metrics = [
            ('Throughput (samples/s)', 'throughput_samples_per_sec'),
            ('Latency (ms/sample)', 'latency_ms_per_sample'),
            ('Memory Usage (MB)', 'memory_usage_mb')
        ]
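        # The arrows below indicate direction of change only; for latency and
        # memory usage a decrease is the desirable outcome.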

        for display_name, key in perf_metrics:
            baseline_val = baseline_perf.get(key, 0)
            finetuned_val = finetuned_perf.get(key, 0)

            if baseline_val == 0:
                change = 0
            else:
                change = ((finetuned_val - baseline_val) / baseline_val) * 100

            change_icon = "⬆️" if change > 0 else "⬇️" if change < 0 else "➡️"
            output.append(f"{display_name:<25} {baseline_val:<15.2f} {finetuned_val:<15.2f} {change_icon} {change:>+7.2f}%")

        # Recommendations
        output.append(f"\n💡 RECOMMENDATIONS:")
        if summary['avg_improvement_pct'] > 5:
            output.append(" • Excellent improvement! Model is ready for production.")
        elif summary['avg_improvement_pct'] > 2:
            output.append(" • Good improvement. Consider additional training or data.")
        elif summary['avg_improvement_pct'] > 0:
            output.append(" • Modest improvement. Review training hyperparameters.")
        else:
            output.append(" • Poor improvement. Review data quality and training strategy.")

        if summary['positive_improvements'] < summary['total_metrics'] // 2:
            output.append(" • Some metrics degraded. Consider ensemble or different approach.")

        output.append(f"\n{'='*80}")

        # Print to console
        for line in output:
            print(line)

        # Save to file if specified
        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write('\n'.join(output))
            logger.info(f"Results saved to: {output_file}")


def parse_args():
    parser = argparse.ArgumentParser(description="Compare BGE models against baselines")

    # Model type
    parser.add_argument("--model_type", type=str, choices=['retriever', 'reranker', 'both'],
                        required=True, help="Type of model to compare")

    # Single model comparison arguments
    parser.add_argument("--finetuned_model", type=str, help="Path to fine-tuned model")
    parser.add_argument("--baseline_model", type=str, help="Path to baseline model")
    parser.add_argument("--data_path", type=str, help="Path to test data")

    # Both models comparison arguments
    parser.add_argument("--finetuned_retriever", type=str, help="Path to fine-tuned retriever")
    parser.add_argument("--finetuned_reranker", type=str, help="Path to fine-tuned reranker")
    parser.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3",
                        help="Baseline retriever model")
    parser.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base",
                        help="Baseline reranker model")
    parser.add_argument("--retriever_data", type=str, help="Path to retriever test data")
    parser.add_argument("--reranker_data", type=str, help="Path to reranker test data")

    # Common arguments
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation")
    parser.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)")
    parser.add_argument("--k_values", type=int, nargs='+', default=[1, 5, 10, 20, 100],
                        help="K values for metrics@k")
    parser.add_argument("--device", type=str, default='auto', help="Device to use")
    parser.add_argument("--output_dir", type=str, default="./comparison_results",
                        help="Directory to save results")

    return parser.parse_args()

def main():
    args = parse_args()

    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Initialize comparator
    comparator = ModelComparator(device=args.device)

    logger.info("🚀 Starting BGE Model Comparison")
    logger.info(f"Model type: {args.model_type}")

    if args.model_type == 'retriever':
        if not all([args.finetuned_model, args.baseline_model, args.data_path]):
            logger.error("For retriever comparison, provide --finetuned_model, --baseline_model, and --data_path")
            return 1

        results = comparator.compare_retrievers(
            args.finetuned_model, args.baseline_model, args.data_path,
            batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
        )

        output_file = Path(args.output_dir) / "retriever_comparison.txt"
        comparator.print_comparison_results(results, str(output_file))

        # Save JSON results
        with open(Path(args.output_dir) / "retriever_comparison.json", 'w') as f:
            json.dump(results, f, indent=2, default=str)

    elif args.model_type == 'reranker':
        if not all([args.finetuned_model, args.baseline_model, args.data_path]):
            logger.error("For reranker comparison, provide --finetuned_model, --baseline_model, and --data_path")
            return 1

        results = comparator.compare_rerankers(
            args.finetuned_model, args.baseline_model, args.data_path,
            batch_size=args.batch_size, max_samples=args.max_samples,
            k_values=[k for k in args.k_values if k <= 10]  # Reranker typically uses smaller K
        )

        output_file = Path(args.output_dir) / "reranker_comparison.txt"
        comparator.print_comparison_results(results, str(output_file))

        # Save JSON results
        with open(Path(args.output_dir) / "reranker_comparison.json", 'w') as f:
            json.dump(results, f, indent=2, default=str)

    elif args.model_type == 'both':
        if not all([args.finetuned_retriever, args.finetuned_reranker,
                    args.retriever_data, args.reranker_data]):
            logger.error("For both comparison, provide all fine-tuned models and data paths")
            return 1

        # Compare retriever
        logger.info("\n" + "="*50)
        logger.info("RETRIEVER COMPARISON")
        logger.info("="*50)

        retriever_results = comparator.compare_retrievers(
            args.finetuned_retriever, args.baseline_retriever, args.retriever_data,
            batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
        )

        # Compare reranker
        logger.info("\n" + "="*50)
        logger.info("RERANKER COMPARISON")
        logger.info("="*50)

        reranker_results = comparator.compare_rerankers(
            args.finetuned_reranker, args.baseline_reranker, args.reranker_data,
            batch_size=args.batch_size, max_samples=args.max_samples,
            k_values=[k for k in args.k_values if k <= 10]
        )

        # Print results
        comparator.print_comparison_results(
            retriever_results, str(Path(args.output_dir) / "retriever_comparison.txt")
        )
        comparator.print_comparison_results(
            reranker_results, str(Path(args.output_dir) / "reranker_comparison.txt")
        )

        # Save combined results
        combined_results = {
            'retriever': retriever_results,
            'reranker': reranker_results,
            'overall_summary': {
                'retriever_status': retriever_results['summary']['status'],
                'reranker_status': reranker_results['summary']['status'],
                'avg_retriever_improvement': retriever_results['summary']['avg_improvement_pct'],
                'avg_reranker_improvement': reranker_results['summary']['avg_improvement_pct']
            }
        }

        with open(Path(args.output_dir) / "combined_comparison.json", 'w') as f:
            json.dump(combined_results, f, indent=2, default=str)

    logger.info(f"✅ Comparison completed! Results saved to: {args.output_dir}")

if __name__ == "__main__":
    sys.exit(main())