Revised for training
This commit is contained in:
611
scripts/compare_models.py
Normal file
611
scripts/compare_models.py
Normal file
@@ -0,0 +1,611 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unified BGE Model Comparison Script
|
||||
|
||||
This script compares fine-tuned BGE models (retriever/reranker) against baseline models,
|
||||
providing both accuracy metrics and performance comparisons.
|
||||
|
||||
Usage:
|
||||
# Compare retriever models
|
||||
python scripts/compare_models.py \
|
||||
--model_type retriever \
|
||||
--finetuned_model ./output/bge_m3_三国演义/final_model \
|
||||
--baseline_model BAAI/bge-m3 \
|
||||
--data_path data/datasets/三国演义/splits/m3_test.jsonl
|
||||
|
||||
# Compare reranker models
|
||||
python scripts/compare_models.py \
|
||||
--model_type reranker \
|
||||
--finetuned_model ./output/bge_reranker_三国演义/final_model \
|
||||
--baseline_model BAAI/bge-reranker-base \
|
||||
--data_path data/datasets/三国演义/splits/reranker_test.jsonl
|
||||
|
||||
# Compare both with comprehensive output
|
||||
python scripts/compare_models.py \
|
||||
--model_type both \
|
||||
--finetuned_retriever ./output/bge_m3_三国演义/final_model \
|
||||
--finetuned_reranker ./output/bge_reranker_三国演义/final_model \
|
||||
--baseline_retriever BAAI/bge-m3 \
|
||||
--baseline_reranker BAAI/bge-reranker-base \
|
||||
--retriever_data data/datasets/三国演义/splits/m3_test.jsonl \
|
||||
--reranker_data data/datasets/三国演义/splits/reranker_test.jsonl
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Optional, Any
|
||||
import torch
|
||||
import psutil
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset
|
||||
from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics
|
||||
from utils.config_loader import get_config
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelComparator:
    """Unified model comparison class for both retriever and reranker models"""

    def __init__(self, device: str = 'auto'):
        # Resolve 'auto' (or an explicit device string such as 'cpu'/'cuda:0')
        # to a concrete torch.device once, at construction time.
        self.device = self._get_device(device)
|
||||
|
||||
def _get_device(self, device: str) -> torch.device:
|
||||
"""Get optimal device"""
|
||||
if device == 'auto':
|
||||
try:
|
||||
config = get_config()
|
||||
device_type = config.get('hardware', {}).get('device', 'cuda')
|
||||
|
||||
if device_type == 'npu' and torch.cuda.is_available():
|
||||
try:
|
||||
import torch_npu
|
||||
return torch.device("npu:0")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
return torch.device(device)
|
||||
|
||||
def measure_memory(self) -> float:
|
||||
"""Measure current memory usage in MB"""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / (1024 * 1024)
|
||||
|
||||
def compare_retrievers(
|
||||
self,
|
||||
finetuned_path: str,
|
||||
baseline_path: str,
|
||||
data_path: str,
|
||||
batch_size: int = 16,
|
||||
max_samples: Optional[int] = None,
|
||||
k_values: list = [1, 5, 10, 20, 100]
|
||||
) -> Dict[str, Any]:
|
||||
"""Compare retriever models"""
|
||||
logger.info("🔍 Comparing Retriever Models")
|
||||
logger.info(f"Fine-tuned: {finetuned_path}")
|
||||
logger.info(f"Baseline: {baseline_path}")
|
||||
|
||||
# Test both models
|
||||
ft_results = self._benchmark_retriever(finetuned_path, data_path, batch_size, max_samples, k_values)
|
||||
baseline_results = self._benchmark_retriever(baseline_path, data_path, batch_size, max_samples, k_values)
|
||||
|
||||
# Compute improvements
|
||||
comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])
|
||||
|
||||
return {
|
||||
'model_type': 'retriever',
|
||||
'finetuned': ft_results,
|
||||
'baseline': baseline_results,
|
||||
'comparison': comparison,
|
||||
'summary': self._generate_retriever_summary(comparison)
|
||||
}
|
||||
|
||||
def compare_rerankers(
|
||||
self,
|
||||
finetuned_path: str,
|
||||
baseline_path: str,
|
||||
data_path: str,
|
||||
batch_size: int = 32,
|
||||
max_samples: Optional[int] = None,
|
||||
k_values: list = [1, 5, 10]
|
||||
) -> Dict[str, Any]:
|
||||
"""Compare reranker models"""
|
||||
logger.info("🏆 Comparing Reranker Models")
|
||||
logger.info(f"Fine-tuned: {finetuned_path}")
|
||||
logger.info(f"Baseline: {baseline_path}")
|
||||
|
||||
# Test both models
|
||||
ft_results = self._benchmark_reranker(finetuned_path, data_path, batch_size, max_samples, k_values)
|
||||
baseline_results = self._benchmark_reranker(baseline_path, data_path, batch_size, max_samples, k_values)
|
||||
|
||||
# Compute improvements
|
||||
comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])
|
||||
|
||||
return {
|
||||
'model_type': 'reranker',
|
||||
'finetuned': ft_results,
|
||||
'baseline': baseline_results,
|
||||
'comparison': comparison,
|
||||
'summary': self._generate_reranker_summary(comparison)
|
||||
}
|
||||
|
||||
    def _benchmark_retriever(
        self,
        model_path: str,
        data_path: str,
        batch_size: int,
        max_samples: Optional[int],
        k_values: list
    ) -> Dict[str, Any]:
        """Benchmark one retriever model: quality metrics + runtime stats.

        Encodes every query/passage pair in the test split, computes
        retrieval metrics from the dense embeddings, and records wall-clock
        throughput, per-sample latency, and process memory growth.

        Args:
            model_path: HF-style path/name of the retriever checkpoint.
            data_path: Test split loaded via BGEM3Dataset.
            batch_size: DataLoader batch size.
            max_samples: Optional cap on the number of evaluated samples.
            k_values: K values forwarded to compute_retrieval_metrics.

        Returns:
            Dict with 'model_path', 'metrics', and a 'performance' sub-dict.
        """
        logger.info(f"Benchmarking retriever: {model_path}")

        # Load tokenizer + model and put the model in inference mode.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGEM3Model.from_pretrained(model_path)
        model.to(self.device)
        model.eval()

        # Load evaluation data (is_train=False: no augmentation/sampling).
        dataset = BGEM3Dataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=False
        )

        # Truncate for quick runs; reaches into dataset.data directly
        # (assumes BGEM3Dataset exposes its samples as a list -- TODO confirm).
        if max_samples and len(dataset) > max_samples:
            dataset.data = dataset.data[:max_samples]

        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        logger.info(f"Testing on {len(dataset)} samples")

        # Benchmark window: memory delta and wall-clock time around the loop.
        start_memory = self.measure_memory()
        start_time = time.time()

        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                # Move tensor entries to the target device; leave others as-is.
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # Encode queries and passages separately; only the dense
                # representation is requested.
                query_outputs = model.forward(
                    input_ids=batch['query_input_ids'],
                    attention_mask=batch['query_attention_mask'],
                    return_dense=True,
                    return_dict=True
                )

                passage_outputs = model.forward(
                    input_ids=batch['passage_input_ids'],
                    attention_mask=batch['passage_attention_mask'],
                    return_dense=True,
                    return_dict=True
                )

                # Accumulate on CPU to keep device memory bounded.
                all_query_embeddings.append(query_outputs['dense'].cpu().numpy())
                all_passage_embeddings.append(passage_outputs['dense'].cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())

        end_time = time.time()
        peak_memory = self.measure_memory()

        # Stack per-batch chunks into full arrays for metric computation.
        query_embeds = np.concatenate(all_query_embeddings, axis=0)
        passage_embeds = np.concatenate(all_passage_embeddings, axis=0)
        labels = np.concatenate(all_labels, axis=0)

        metrics = compute_retrieval_metrics(
            query_embeds, passage_embeds, labels, k_values=k_values
        )

        # Performance stats derived from the benchmark window.
        total_time = end_time - start_time
        throughput = len(dataset) / total_time
        latency = (total_time * 1000) / len(dataset)  # ms per sample
        memory_usage = peak_memory - start_memory

        return {
            'model_path': model_path,
            'metrics': metrics,
            'performance': {
                'throughput_samples_per_sec': throughput,
                'latency_ms_per_sample': latency,
                'memory_usage_mb': memory_usage,
                'total_time_sec': total_time,
                'samples_processed': len(dataset)
            }
        }
|
||||
|
||||
def _benchmark_reranker(
|
||||
self,
|
||||
model_path: str,
|
||||
data_path: str,
|
||||
batch_size: int,
|
||||
max_samples: Optional[int],
|
||||
k_values: list
|
||||
) -> Dict[str, Any]:
|
||||
"""Benchmark a reranker model"""
|
||||
logger.info(f"Benchmarking reranker: {model_path}")
|
||||
|
||||
# Load model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# Load data
|
||||
dataset = BGERerankerDataset(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
if max_samples and len(dataset) > max_samples:
|
||||
dataset.data = dataset.data[:max_samples]
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
||||
logger.info(f"Testing on {len(dataset)} samples")
|
||||
|
||||
# Benchmark
|
||||
start_memory = self.measure_memory()
|
||||
start_time = time.time()
|
||||
|
||||
all_scores = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
# Move to device
|
||||
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in batch.items()}
|
||||
|
||||
# Get scores
|
||||
outputs = model.forward(
|
||||
input_ids=batch['input_ids'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
logits = outputs.logits if hasattr(outputs, 'logits') else outputs.scores
|
||||
scores = torch.sigmoid(logits.squeeze()).cpu().numpy()
|
||||
|
||||
all_scores.append(scores)
|
||||
all_labels.append(batch['labels'].cpu().numpy())
|
||||
|
||||
end_time = time.time()
|
||||
peak_memory = self.measure_memory()
|
||||
|
||||
# Compute metrics
|
||||
scores = np.concatenate(all_scores, axis=0)
|
||||
labels = np.concatenate(all_labels, axis=0)
|
||||
|
||||
metrics = compute_reranker_metrics(scores, labels, k_values=k_values)
|
||||
|
||||
# Performance stats
|
||||
total_time = end_time - start_time
|
||||
throughput = len(dataset) / total_time
|
||||
latency = (total_time * 1000) / len(dataset) # ms per sample
|
||||
memory_usage = peak_memory - start_memory
|
||||
|
||||
return {
|
||||
'model_path': model_path,
|
||||
'metrics': metrics,
|
||||
'performance': {
|
||||
'throughput_samples_per_sec': throughput,
|
||||
'latency_ms_per_sample': latency,
|
||||
'memory_usage_mb': memory_usage,
|
||||
'total_time_sec': total_time,
|
||||
'samples_processed': len(dataset)
|
||||
}
|
||||
}
|
||||
|
||||
def _compute_metric_improvements(self, baseline_metrics: Dict, finetuned_metrics: Dict) -> Dict[str, float]:
|
||||
"""Compute percentage improvements for each metric"""
|
||||
improvements = {}
|
||||
|
||||
for metric_name in baseline_metrics:
|
||||
baseline_val = baseline_metrics[metric_name]
|
||||
finetuned_val = finetuned_metrics.get(metric_name, baseline_val)
|
||||
|
||||
if baseline_val == 0:
|
||||
improvement = 0.0 if finetuned_val == 0 else float('inf')
|
||||
else:
|
||||
improvement = ((finetuned_val - baseline_val) / baseline_val) * 100
|
||||
|
||||
improvements[metric_name] = improvement
|
||||
|
||||
return improvements
|
||||
|
||||
def _generate_retriever_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
|
||||
"""Generate summary for retriever comparison"""
|
||||
key_metrics = ['recall@5', 'recall@10', 'map', 'mrr']
|
||||
improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]
|
||||
|
||||
avg_improvement = np.mean(improvements) if improvements else 0
|
||||
positive_improvements = sum(1 for imp in improvements if imp > 0)
|
||||
|
||||
if avg_improvement > 5:
|
||||
status = "🌟 EXCELLENT"
|
||||
elif avg_improvement > 2:
|
||||
status = "✅ GOOD"
|
||||
elif avg_improvement > 0:
|
||||
status = "👌 FAIR"
|
||||
else:
|
||||
status = "❌ POOR"
|
||||
|
||||
return {
|
||||
'status': status,
|
||||
'avg_improvement_pct': avg_improvement,
|
||||
'positive_improvements': positive_improvements,
|
||||
'total_metrics': len(improvements),
|
||||
'key_improvements': {metric: comparison.get(metric, 0) for metric in key_metrics if metric in comparison}
|
||||
}
|
||||
|
||||
def _generate_reranker_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
|
||||
"""Generate summary for reranker comparison"""
|
||||
key_metrics = ['accuracy', 'precision', 'recall', 'f1', 'auc']
|
||||
improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]
|
||||
|
||||
avg_improvement = np.mean(improvements) if improvements else 0
|
||||
positive_improvements = sum(1 for imp in improvements if imp > 0)
|
||||
|
||||
if avg_improvement > 3:
|
||||
status = "🌟 EXCELLENT"
|
||||
elif avg_improvement > 1:
|
||||
status = "✅ GOOD"
|
||||
elif avg_improvement > 0:
|
||||
status = "👌 FAIR"
|
||||
else:
|
||||
status = "❌ POOR"
|
||||
|
||||
return {
|
||||
'status': status,
|
||||
'avg_improvement_pct': avg_improvement,
|
||||
'positive_improvements': positive_improvements,
|
||||
'total_metrics': len(improvements),
|
||||
'key_improvements': {metric: comparison.get(metric, 0) for metric in key_metrics if metric in comparison}
|
||||
}
|
||||
|
||||
    def print_comparison_results(self, results: Dict[str, Any], output_file: Optional[str] = None):
        """Pretty-print a comparison result dict and optionally save it.

        Args:
            results: Output of compare_retrievers()/compare_rerankers()
                (expects 'model_type', 'summary', 'baseline', 'finetuned').
            output_file: If given, the same report text is written there.
        """
        model_type = results['model_type'].upper()
        summary = results['summary']

        # Build the whole report as a list of lines so it can be printed
        # and written to disk from the same buffer.
        output = []
        output.append(f"\n{'='*80}")
        output.append(f"📊 {model_type} MODEL COMPARISON RESULTS")
        output.append(f"{'='*80}")

        output.append(f"\n🎯 OVERALL ASSESSMENT: {summary['status']}")
        output.append(f"📈 Average Improvement: {summary['avg_improvement_pct']:.2f}%")
        output.append(f"✅ Metrics Improved: {summary['positive_improvements']}/{summary['total_metrics']}")

        # Key accuracy metrics: baseline vs fine-tuned, with per-metric delta.
        output.append(f"\n📋 KEY METRICS COMPARISON:")
        output.append(f"{'Metric':<20} {'Baseline':<12} {'Fine-tuned':<12} {'Improvement':<12}")
        output.append("-" * 60)

        for metric, improvement in summary['key_improvements'].items():
            baseline_val = results['baseline']['metrics'].get(metric, 0)
            finetuned_val = results['finetuned']['metrics'].get(metric, 0)

            status_icon = "📈" if improvement > 0 else "📉" if improvement < 0 else "➡️"
            output.append(f"{metric:<20} {baseline_val:<12.4f} {finetuned_val:<12.4f} {status_icon} {improvement:>+7.2f}%")

        # Runtime characteristics: throughput, latency, memory.
        output.append(f"\n⚡ PERFORMANCE COMPARISON:")
        baseline_perf = results['baseline']['performance']
        finetuned_perf = results['finetuned']['performance']

        output.append(f"{'Metric':<25} {'Baseline':<15} {'Fine-tuned':<15} {'Change':<15}")
        output.append("-" * 75)

        # (display label, key into the 'performance' dict)
        perf_metrics = [
            ('Throughput (samples/s)', 'throughput_samples_per_sec'),
            ('Latency (ms/sample)', 'latency_ms_per_sample'),
            ('Memory Usage (MB)', 'memory_usage_mb')
        ]

        for display_name, key in perf_metrics:
            baseline_val = baseline_perf.get(key, 0)
            finetuned_val = finetuned_perf.get(key, 0)

            # Guard against division by zero for missing/zero baselines.
            if baseline_val == 0:
                change = 0
            else:
                change = ((finetuned_val - baseline_val) / baseline_val) * 100

            change_icon = "⬆️" if change > 0 else "⬇️" if change < 0 else "➡️"
            output.append(f"{display_name:<25} {baseline_val:<15.2f} {finetuned_val:<15.2f} {change_icon} {change:>+7.2f}%")

        # Text recommendations keyed off the same thresholds as the summary.
        output.append(f"\n💡 RECOMMENDATIONS:")
        if summary['avg_improvement_pct'] > 5:
            output.append(" • Excellent improvement! Model is ready for production.")
        elif summary['avg_improvement_pct'] > 2:
            output.append(" • Good improvement. Consider additional training or data.")
        elif summary['avg_improvement_pct'] > 0:
            output.append(" • Modest improvement. Review training hyperparameters.")
        else:
            output.append(" • Poor improvement. Review data quality and training strategy.")

        # Warn when fewer than half of the tracked metrics improved.
        if summary['positive_improvements'] < summary['total_metrics'] // 2:
            output.append(" • Some metrics degraded. Consider ensemble or different approach.")

        output.append(f"\n{'='*80}")

        # Print to console
        for line in output:
            print(line)

        # Save to file if specified
        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write('\n'.join(output))
            logger.info(f"Results saved to: {output_file}")
|
||||
|
||||
|
||||
def parse_args():
    """Build the CLI and parse ``sys.argv`` into a namespace."""
    cli = argparse.ArgumentParser(description="Compare BGE models against baselines")

    # Which comparison mode to run.
    cli.add_argument("--model_type", type=str, choices=['retriever', 'reranker', 'both'],
                     required=True, help="Type of model to compare")

    # Arguments for a single-model comparison (retriever OR reranker mode).
    cli.add_argument("--finetuned_model", type=str, help="Path to fine-tuned model")
    cli.add_argument("--baseline_model", type=str, help="Path to baseline model")
    cli.add_argument("--data_path", type=str, help="Path to test data")

    # Arguments for the dual comparison ('both' mode).
    cli.add_argument("--finetuned_retriever", type=str, help="Path to fine-tuned retriever")
    cli.add_argument("--finetuned_reranker", type=str, help="Path to fine-tuned reranker")
    cli.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3",
                     help="Baseline retriever model")
    cli.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base",
                     help="Baseline reranker model")
    cli.add_argument("--retriever_data", type=str, help="Path to retriever test data")
    cli.add_argument("--reranker_data", type=str, help="Path to reranker test data")

    # Evaluation knobs shared by every mode.
    cli.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation")
    cli.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)")
    cli.add_argument("--k_values", type=int, nargs='+', default=[1, 5, 10, 20, 100],
                     help="K values for metrics@k")
    cli.add_argument("--device", type=str, default='auto', help="Device to use")
    cli.add_argument("--output_dir", type=str, default="./comparison_results",
                     help="Directory to save results")

    return cli.parse_args()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: run the requested comparison and persist results.

    Returns 1 on argument-validation failure; otherwise returns None after
    writing text and JSON reports into ``--output_dir``.
    """
    args = parse_args()

    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Initialize comparator
    comparator = ModelComparator(device=args.device)

    logger.info("🚀 Starting BGE Model Comparison")
    logger.info(f"Model type: {args.model_type}")

    if args.model_type == 'retriever':
        # Retriever-only mode requires the single-model argument triple.
        if not all([args.finetuned_model, args.baseline_model, args.data_path]):
            logger.error("For retriever comparison, provide --finetuned_model, --baseline_model, and --data_path")
            return 1

        results = comparator.compare_retrievers(
            args.finetuned_model, args.baseline_model, args.data_path,
            batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
        )

        output_file = Path(args.output_dir) / "retriever_comparison.txt"
        comparator.print_comparison_results(results, str(output_file))

        # Save JSON results (default=str stringifies non-serializable values)
        with open(Path(args.output_dir) / "retriever_comparison.json", 'w') as f:
            json.dump(results, f, indent=2, default=str)

    elif args.model_type == 'reranker':
        # Reranker-only mode requires the single-model argument triple.
        if not all([args.finetuned_model, args.baseline_model, args.data_path]):
            logger.error("For reranker comparison, provide --finetuned_model, --baseline_model, and --data_path")
            return 1

        results = comparator.compare_rerankers(
            args.finetuned_model, args.baseline_model, args.data_path,
            batch_size=args.batch_size, max_samples=args.max_samples,
            k_values=[k for k in args.k_values if k <= 10]  # Reranker typically uses smaller K
        )

        output_file = Path(args.output_dir) / "reranker_comparison.txt"
        comparator.print_comparison_results(results, str(output_file))

        # Save JSON results (default=str stringifies non-serializable values)
        with open(Path(args.output_dir) / "reranker_comparison.json", 'w') as f:
            json.dump(results, f, indent=2, default=str)

    elif args.model_type == 'both':
        # Dual mode requires both fine-tuned models and both data paths;
        # the baselines have defaults and are therefore not validated here.
        if not all([args.finetuned_retriever, args.finetuned_reranker,
                    args.retriever_data, args.reranker_data]):
            logger.error("For both comparison, provide all fine-tuned models and data paths")
            return 1

        # Compare retriever
        logger.info("\n" + "="*50)
        logger.info("RETRIEVER COMPARISON")
        logger.info("="*50)

        retriever_results = comparator.compare_retrievers(
            args.finetuned_retriever, args.baseline_retriever, args.retriever_data,
            batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
        )

        # Compare reranker
        logger.info("\n" + "="*50)
        logger.info("RERANKER COMPARISON")
        logger.info("="*50)

        reranker_results = comparator.compare_rerankers(
            args.finetuned_reranker, args.baseline_reranker, args.reranker_data,
            batch_size=args.batch_size, max_samples=args.max_samples,
            k_values=[k for k in args.k_values if k <= 10]
        )

        # Print (and save) both text reports.
        comparator.print_comparison_results(
            retriever_results, str(Path(args.output_dir) / "retriever_comparison.txt")
        )
        comparator.print_comparison_results(
            reranker_results, str(Path(args.output_dir) / "reranker_comparison.txt")
        )

        # Save combined results plus a compact cross-model summary.
        combined_results = {
            'retriever': retriever_results,
            'reranker': reranker_results,
            'overall_summary': {
                'retriever_status': retriever_results['summary']['status'],
                'reranker_status': reranker_results['summary']['status'],
                'avg_retriever_improvement': retriever_results['summary']['avg_improvement_pct'],
                'avg_reranker_improvement': reranker_results['summary']['avg_improvement_pct']
            }
        }

        with open(Path(args.output_dir) / "combined_comparison.json", 'w') as f:
            json.dump(combined_results, f, indent=2, default=str)

    logger.info(f"✅ Comparison completed! Results saved to: {args.output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # BUGFIX: main() returns 1 on argument-validation errors, but the return
    # value used to be discarded so the process always exited 0. Propagate it
    # as the process exit status (None maps to 0 on success).
    sys.exit(main())
|
||||
Reference in New Issue
Block a user