bge_finetune/scripts/compare_models.py
#!/usr/bin/env python3
"""
Unified BGE Model Comparison Script
This script compares fine-tuned BGE models (retriever/reranker) against baseline models,
providing both accuracy metrics and performance comparisons.
Usage:
# Compare retriever models
python scripts/compare_models.py \
--model_type retriever \
--finetuned_model ./output/bge_m3_三国演义/final_model \
--baseline_model BAAI/bge-m3 \
--data_path data/datasets/三国演义/splits/m3_test.jsonl
# Compare reranker models
python scripts/compare_models.py \
--model_type reranker \
--finetuned_model ./output/bge_reranker_三国演义/final_model \
--baseline_model BAAI/bge-reranker-base \
--data_path data/datasets/三国演义/splits/reranker_test.jsonl
# Compare both with comprehensive output
python scripts/compare_models.py \
--model_type both \
--finetuned_retriever ./output/bge_m3_三国演义/final_model \
--finetuned_reranker ./output/bge_reranker_三国演义/final_model \
--baseline_retriever BAAI/bge-m3 \
--baseline_reranker BAAI/bge-reranker-base \
--retriever_data data/datasets/三国演义/splits/m3_test.jsonl \
--reranker_data data/datasets/三国演义/splits/reranker_test.jsonl
"""
import os
import sys
import time
import argparse
import logging
import json
from pathlib import Path
from typing import Dict, Tuple, Optional, Any
import torch
import psutil
import numpy as np
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics
from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ModelComparator:
    """Unified model comparison class for both retriever and reranker models"""

    def __init__(self, device: str = 'auto'):
        self.device = self._get_device(device)

    def _get_device(self, device: str) -> torch.device:
        """Resolve the best available device: configured NPU, then CUDA, then CPU"""
        if device == 'auto':
            try:
                config = get_config()
                device_type = config.get('hardware', {}).get('device', 'cuda')
                if device_type == 'npu':
                    try:
                        import torch_npu  # registers the torch.npu namespace
                        if torch.npu.is_available():
                            return torch.device("npu:0")
                    except ImportError:
                        pass
                if torch.cuda.is_available():
                    return torch.device("cuda")
            except Exception:
                pass
            return torch.device("cpu")
        else:
            return torch.device(device)

    def measure_memory(self) -> float:
        """Measure current process memory usage (RSS) in MB"""
        process = psutil.Process(os.getpid())
        return process.memory_info().rss / (1024 * 1024)
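
    # ------------------------------------------------------------------
    # Comparison entry points. Each one benchmarks the fine-tuned and the
    # baseline model on the same test split, computes per-metric improvement
    # percentages, and attaches a short human-readable summary.
    # ------------------------------------------------------------------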

    def compare_retrievers(
        self,
        finetuned_path: str,
        baseline_path: str,
        data_path: str,
        batch_size: int = 16,
        max_samples: Optional[int] = None,
        k_values: list = [1, 5, 10, 20, 100]
    ) -> Dict[str, Any]:
        """Compare retriever models"""
        logger.info("🔍 Comparing Retriever Models")
        logger.info(f"Fine-tuned: {finetuned_path}")
        logger.info(f"Baseline: {baseline_path}")

        # Test both models
        ft_results = self._benchmark_retriever(finetuned_path, data_path, batch_size, max_samples, k_values)
        baseline_results = self._benchmark_retriever(baseline_path, data_path, batch_size, max_samples, k_values)

        # Compute improvements
        comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])

        return {
            'model_type': 'retriever',
            'finetuned': ft_results,
            'baseline': baseline_results,
            'comparison': comparison,
            'summary': self._generate_retriever_summary(comparison)
        }

    def compare_rerankers(
        self,
        finetuned_path: str,
        baseline_path: str,
        data_path: str,
        batch_size: int = 32,
        max_samples: Optional[int] = None,
        k_values: list = [1, 5, 10]
    ) -> Dict[str, Any]:
        """Compare reranker models"""
        logger.info("🏆 Comparing Reranker Models")
        logger.info(f"Fine-tuned: {finetuned_path}")
        logger.info(f"Baseline: {baseline_path}")

        # Test both models
        ft_results = self._benchmark_reranker(finetuned_path, data_path, batch_size, max_samples, k_values)
        baseline_results = self._benchmark_reranker(baseline_path, data_path, batch_size, max_samples, k_values)

        # Compute improvements
        comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])

        return {
            'model_type': 'reranker',
            'finetuned': ft_results,
            'baseline': baseline_results,
            'comparison': comparison,
            'summary': self._generate_reranker_summary(comparison)
        }

    def _benchmark_retriever(
        self,
        model_path: str,
        data_path: str,
        batch_size: int,
        max_samples: Optional[int],
        k_values: list
    ) -> Dict[str, Any]:
        """Benchmark a retriever model"""
        logger.info(f"Benchmarking retriever: {model_path}")

        # Load model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGEM3Model.from_pretrained(model_path)
        model.to(self.device)
        model.eval()

        # Load data
        dataset = BGEM3Dataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=False
        )
        if max_samples and len(dataset) > max_samples:
            dataset.data = dataset.data[:max_samples]
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        logger.info(f"Testing on {len(dataset)} samples")

        # Benchmark
        start_memory = self.measure_memory()
        start_time = time.time()
        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []
        with torch.no_grad():
            for batch in dataloader:
                # Move to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}
                # Get embeddings
                query_outputs = model.forward(
                    input_ids=batch['query_input_ids'],
                    attention_mask=batch['query_attention_mask'],
                    return_dense=True,
                    return_dict=True
                )
                passage_outputs = model.forward(
                    input_ids=batch['passage_input_ids'],
                    attention_mask=batch['passage_attention_mask'],
                    return_dense=True,
                    return_dict=True
                )
                all_query_embeddings.append(query_outputs['dense'].cpu().numpy())
                all_passage_embeddings.append(passage_outputs['dense'].cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())
        end_time = time.time()
        peak_memory = self.measure_memory()

        # Compute metrics
        query_embeds = np.concatenate(all_query_embeddings, axis=0)
        passage_embeds = np.concatenate(all_passage_embeddings, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        metrics = compute_retrieval_metrics(
            query_embeds, passage_embeds, labels, k_values=k_values
        )

        # Performance stats
        total_time = end_time - start_time
        throughput = len(dataset) / total_time
        latency = (total_time * 1000) / len(dataset)  # ms per sample
        memory_usage = peak_memory - start_memory

        return {
            'model_path': model_path,
            'metrics': metrics,
            'performance': {
                'throughput_samples_per_sec': throughput,
                'latency_ms_per_sample': latency,
                'memory_usage_mb': memory_usage,
                'total_time_sec': total_time,
                'samples_processed': len(dataset)
            }
        }
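
    # Note: the retrieval quality numbers above come entirely from
    # compute_retrieval_metrics (evaluation/metrics.py), which is expected to
    # score each query against the passages (e.g. by dense-embedding similarity)
    # and derive recall@k / MAP / MRR from the labels; timing and memory are
    # measured around the embedding loop only.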

    def _benchmark_reranker(
        self,
        model_path: str,
        data_path: str,
        batch_size: int,
        max_samples: Optional[int],
        k_values: list
    ) -> Dict[str, Any]:
        """Benchmark a reranker model"""
        logger.info(f"Benchmarking reranker: {model_path}")

        # Load model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGERerankerModel.from_pretrained(model_path)
        model.to(self.device)
        model.eval()

        # Load data
        dataset = BGERerankerDataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=False
        )
        if max_samples and len(dataset) > max_samples:
            dataset.data = dataset.data[:max_samples]
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        logger.info(f"Testing on {len(dataset)} samples")

        # Benchmark
        start_memory = self.measure_memory()
        start_time = time.time()
        all_scores = []
        all_labels = []
        with torch.no_grad():
            for batch in dataloader:
                # Move to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}
                # Get scores
                outputs = model.forward(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    return_dict=True
                )
                logits = outputs.logits if hasattr(outputs, 'logits') else outputs.scores
                # squeeze(-1) (not squeeze()) keeps a 1-D array even when a batch has a single sample
                scores = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()
                all_scores.append(scores)
                all_labels.append(batch['labels'].cpu().numpy())
        end_time = time.time()
        peak_memory = self.measure_memory()

        # Compute metrics
        scores = np.concatenate(all_scores, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        metrics = compute_reranker_metrics(scores, labels, k_values=k_values)

        # Performance stats
        total_time = end_time - start_time
        throughput = len(dataset) / total_time
        latency = (total_time * 1000) / len(dataset)  # ms per sample
        memory_usage = peak_memory - start_memory

        return {
            'model_path': model_path,
            'metrics': metrics,
            'performance': {
                'throughput_samples_per_sec': throughput,
                'latency_ms_per_sample': latency,
                'memory_usage_mb': memory_usage,
                'total_time_sec': total_time,
                'samples_processed': len(dataset)
            }
        }
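
    # As with the retriever, scoring quality is judged offline:
    # compute_reranker_metrics (evaluation/metrics.py) is assumed to turn the
    # sigmoid scores and labels into accuracy / precision / recall / F1 / AUC,
    # which are the key metrics summarised further below.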

    def _compute_metric_improvements(self, baseline_metrics: Dict, finetuned_metrics: Dict) -> Dict[str, float]:
        """Compute percentage improvements for each metric"""
        improvements = {}
        for metric_name in baseline_metrics:
            baseline_val = baseline_metrics[metric_name]
            finetuned_val = finetuned_metrics.get(metric_name, baseline_val)
            if baseline_val == 0:
                improvement = 0.0 if finetuned_val == 0 else float('inf')
            else:
                improvement = ((finetuned_val - baseline_val) / baseline_val) * 100
            improvements[metric_name] = improvement
        return improvements
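
    # Worked example for _compute_metric_improvements: baseline recall@5 = 0.60
    # and fine-tuned recall@5 = 0.66 gives (0.66 - 0.60) / 0.60 * 100 = +10.0%.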

    def _generate_retriever_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
        """Generate summary for retriever comparison"""
        key_metrics = ['recall@5', 'recall@10', 'map', 'mrr']
        improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]
        avg_improvement = np.mean(improvements) if improvements else 0
        positive_improvements = sum(1 for imp in improvements if imp > 0)

        if avg_improvement > 5:
            status = "🌟 EXCELLENT"
        elif avg_improvement > 2:
            status = "✅ GOOD"
        elif avg_improvement > 0:
            status = "👌 FAIR"
        else:
            status = "❌ POOR"

        return {
            'status': status,
            'avg_improvement_pct': avg_improvement,
            'positive_improvements': positive_improvements,
            'total_metrics': len(improvements),
            'key_improvements': {metric: comparison.get(metric, 0) for metric in key_metrics if metric in comparison}
        }

    def _generate_reranker_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
        """Generate summary for reranker comparison"""
        key_metrics = ['accuracy', 'precision', 'recall', 'f1', 'auc']
        improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]
        avg_improvement = np.mean(improvements) if improvements else 0
        positive_improvements = sum(1 for imp in improvements if imp > 0)

        if avg_improvement > 3:
            status = "🌟 EXCELLENT"
        elif avg_improvement > 1:
            status = "✅ GOOD"
        elif avg_improvement > 0:
            status = "👌 FAIR"
        else:
            status = "❌ POOR"

        return {
            'status': status,
            'avg_improvement_pct': avg_improvement,
            'positive_improvements': positive_improvements,
            'total_metrics': len(improvements),
            'key_improvements': {metric: comparison.get(metric, 0) for metric in key_metrics if metric in comparison}
        }

    def print_comparison_results(self, results: Dict[str, Any], output_file: Optional[str] = None):
        """Print formatted comparison results"""
        model_type = results['model_type'].upper()
        summary = results['summary']

        output = []
        output.append(f"\n{'='*80}")
        output.append(f"📊 {model_type} MODEL COMPARISON RESULTS")
        output.append(f"{'='*80}")
        output.append(f"\n🎯 OVERALL ASSESSMENT: {summary['status']}")
        output.append(f"📈 Average Improvement: {summary['avg_improvement_pct']:.2f}%")
        output.append(f"✅ Metrics Improved: {summary['positive_improvements']}/{summary['total_metrics']}")

        # Key metrics comparison
        output.append("\n📋 KEY METRICS COMPARISON:")
        output.append(f"{'Metric':<20} {'Baseline':<12} {'Fine-tuned':<12} {'Improvement':<12}")
        output.append("-" * 60)
        for metric, improvement in summary['key_improvements'].items():
            baseline_val = results['baseline']['metrics'].get(metric, 0)
            finetuned_val = results['finetuned']['metrics'].get(metric, 0)
            status_icon = "📈" if improvement > 0 else "📉" if improvement < 0 else "➡️"
            output.append(f"{metric:<20} {baseline_val:<12.4f} {finetuned_val:<12.4f} {status_icon} {improvement:>+7.2f}%")

        # Performance comparison
        output.append("\n⚡ PERFORMANCE COMPARISON:")
        baseline_perf = results['baseline']['performance']
        finetuned_perf = results['finetuned']['performance']
        output.append(f"{'Metric':<25} {'Baseline':<15} {'Fine-tuned':<15} {'Change':<15}")
        output.append("-" * 75)
        perf_metrics = [
            ('Throughput (samples/s)', 'throughput_samples_per_sec'),
            ('Latency (ms/sample)', 'latency_ms_per_sample'),
            ('Memory Usage (MB)', 'memory_usage_mb')
        ]
        for display_name, key in perf_metrics:
            baseline_val = baseline_perf.get(key, 0)
            finetuned_val = finetuned_perf.get(key, 0)
            if baseline_val == 0:
                change = 0
            else:
                change = ((finetuned_val - baseline_val) / baseline_val) * 100
            change_icon = "⬆️" if change > 0 else "⬇️" if change < 0 else "➡️"
            output.append(f"{display_name:<25} {baseline_val:<15.2f} {finetuned_val:<15.2f} {change_icon} {change:>+7.2f}%")

        # Recommendations
        output.append("\n💡 RECOMMENDATIONS:")
        if summary['avg_improvement_pct'] > 5:
            output.append(" • Excellent improvement! Model is ready for production.")
        elif summary['avg_improvement_pct'] > 2:
            output.append(" • Good improvement. Consider additional training or data.")
        elif summary['avg_improvement_pct'] > 0:
            output.append(" • Modest improvement. Review training hyperparameters.")
        else:
            output.append(" • Poor improvement. Review data quality and training strategy.")
        if summary['positive_improvements'] < summary['total_metrics'] // 2:
            output.append(" • Some metrics degraded. Consider ensemble or different approach.")
        output.append(f"\n{'='*80}")

        # Print to console
        for line in output:
            print(line)

        # Save to file if specified
        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write('\n'.join(output))
            logger.info(f"Results saved to: {output_file}")


def parse_args():
    parser = argparse.ArgumentParser(description="Compare BGE models against baselines")

    # Model type
    parser.add_argument("--model_type", type=str, choices=['retriever', 'reranker', 'both'],
                        required=True, help="Type of model to compare")

    # Single model comparison arguments
    parser.add_argument("--finetuned_model", type=str, help="Path to fine-tuned model")
    parser.add_argument("--baseline_model", type=str, help="Path to baseline model")
    parser.add_argument("--data_path", type=str, help="Path to test data")

    # Both models comparison arguments
    parser.add_argument("--finetuned_retriever", type=str, help="Path to fine-tuned retriever")
    parser.add_argument("--finetuned_reranker", type=str, help="Path to fine-tuned reranker")
    parser.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3",
                        help="Baseline retriever model")
    parser.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base",
                        help="Baseline reranker model")
    parser.add_argument("--retriever_data", type=str, help="Path to retriever test data")
    parser.add_argument("--reranker_data", type=str, help="Path to reranker test data")

    # Common arguments
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation")
    parser.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)")
    parser.add_argument("--k_values", type=int, nargs='+', default=[1, 5, 10, 20, 100],
                        help="K values for metrics@k")
    parser.add_argument("--device", type=str, default='auto', help="Device to use")
    parser.add_argument("--output_dir", type=str, default="./comparison_results",
                        help="Directory to save results")

    return parser.parse_args()


def main():
    args = parse_args()

    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Initialize comparator
    comparator = ModelComparator(device=args.device)

    logger.info("🚀 Starting BGE Model Comparison")
    logger.info(f"Model type: {args.model_type}")

    if args.model_type == 'retriever':
        if not all([args.finetuned_model, args.baseline_model, args.data_path]):
            logger.error("For retriever comparison, provide --finetuned_model, --baseline_model, and --data_path")
            return 1
        results = comparator.compare_retrievers(
            args.finetuned_model, args.baseline_model, args.data_path,
            batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
        )
        output_file = Path(args.output_dir) / "retriever_comparison.txt"
        comparator.print_comparison_results(results, str(output_file))
        # Save JSON results
        with open(Path(args.output_dir) / "retriever_comparison.json", 'w') as f:
            json.dump(results, f, indent=2, default=str)

    elif args.model_type == 'reranker':
        if not all([args.finetuned_model, args.baseline_model, args.data_path]):
            logger.error("For reranker comparison, provide --finetuned_model, --baseline_model, and --data_path")
            return 1
        results = comparator.compare_rerankers(
            args.finetuned_model, args.baseline_model, args.data_path,
            batch_size=args.batch_size, max_samples=args.max_samples,
            k_values=[k for k in args.k_values if k <= 10]  # Rerankers typically use smaller K
        )
        output_file = Path(args.output_dir) / "reranker_comparison.txt"
        comparator.print_comparison_results(results, str(output_file))
        # Save JSON results
        with open(Path(args.output_dir) / "reranker_comparison.json", 'w') as f:
            json.dump(results, f, indent=2, default=str)

    elif args.model_type == 'both':
        if not all([args.finetuned_retriever, args.finetuned_reranker,
                    args.retriever_data, args.reranker_data]):
            logger.error("For 'both', provide both fine-tuned model paths and both data paths")
            return 1

        # Compare retriever
        logger.info("\n" + "=" * 50)
        logger.info("RETRIEVER COMPARISON")
        logger.info("=" * 50)
        retriever_results = comparator.compare_retrievers(
            args.finetuned_retriever, args.baseline_retriever, args.retriever_data,
            batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
        )

        # Compare reranker
        logger.info("\n" + "=" * 50)
        logger.info("RERANKER COMPARISON")
        logger.info("=" * 50)
        reranker_results = comparator.compare_rerankers(
            args.finetuned_reranker, args.baseline_reranker, args.reranker_data,
            batch_size=args.batch_size, max_samples=args.max_samples,
            k_values=[k for k in args.k_values if k <= 10]
        )

        # Print results
        comparator.print_comparison_results(
            retriever_results, str(Path(args.output_dir) / "retriever_comparison.txt")
        )
        comparator.print_comparison_results(
            reranker_results, str(Path(args.output_dir) / "reranker_comparison.txt")
        )

        # Save combined results
        combined_results = {
            'retriever': retriever_results,
            'reranker': reranker_results,
            'overall_summary': {
                'retriever_status': retriever_results['summary']['status'],
                'reranker_status': reranker_results['summary']['status'],
                'avg_retriever_improvement': retriever_results['summary']['avg_improvement_pct'],
                'avg_reranker_improvement': reranker_results['summary']['avg_improvement_pct']
            }
        }
        with open(Path(args.output_dir) / "combined_comparison.json", 'w') as f:
            json.dump(combined_results, f, indent=2, default=str)

    logger.info(f"✅ Comparison completed! Results saved to: {args.output_dir}")
    return 0


if __name__ == "__main__":
    sys.exit(main())