#!/usr/bin/env python3
"""
Unified BGE Model Comparison Script

This script compares fine-tuned BGE models (retriever/reranker) against
baseline models, providing both accuracy metrics and performance comparisons.

Usage:
    # Compare retriever models
    python scripts/compare_models.py \
        --model_type retriever \
        --finetuned_model ./output/bge_m3_三国演义/final_model \
        --baseline_model BAAI/bge-m3 \
        --data_path data/datasets/三国演义/splits/m3_test.jsonl

    # Compare reranker models
    python scripts/compare_models.py \
        --model_type reranker \
        --finetuned_model ./output/bge_reranker_三国演义/final_model \
        --baseline_model BAAI/bge-reranker-base \
        --data_path data/datasets/三国演义/splits/reranker_test.jsonl

    # Compare both with comprehensive output
    python scripts/compare_models.py \
        --model_type both \
        --finetuned_retriever ./output/bge_m3_三国演义/final_model \
        --finetuned_reranker ./output/bge_reranker_三国演义/final_model \
        --baseline_retriever BAAI/bge-m3 \
        --baseline_reranker BAAI/bge-reranker-base \
        --retriever_data data/datasets/三国演义/splits/m3_test.jsonl \
        --reranker_data data/datasets/三国演义/splits/reranker_test.jsonl
"""

import os
import sys
import time
import argparse
import logging
import json
from pathlib import Path
from typing import Dict, Optional, Any

import torch
import psutil
import numpy as np
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics
from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ModelComparator:
    """Unified model comparison class for both retriever and reranker models."""

    def __init__(self, device: str = 'auto'):
        self.device = self._get_device(device)

    def _get_device(self, device: str) -> torch.device:
        """Resolve the best available device (NPU > CUDA > CPU)."""
        if device == 'auto':
            try:
                config = get_config()
                device_type = config.get('hardware', {}).get('device', 'cuda')
                if device_type == 'npu':
                    try:
                        import torch_npu  # noqa: F401 (registers the torch.npu backend)
                        if torch.npu.is_available():
                            return torch.device("npu:0")
                    except ImportError:
                        pass
                if torch.cuda.is_available():
                    return torch.device("cuda")
            except Exception:
                pass
            return torch.device("cpu")
        else:
            return torch.device(device)

    def measure_memory(self) -> float:
        """Measure current process memory usage (RSS) in MB."""
        process = psutil.Process(os.getpid())
        return process.memory_info().rss / (1024 * 1024)

    def compare_retrievers(
        self,
        finetuned_path: str,
        baseline_path: str,
        data_path: str,
        batch_size: int = 16,
        max_samples: Optional[int] = None,
        k_values: list = [1, 5, 10, 20, 100]
    ) -> Dict[str, Any]:
        """Compare retriever models."""
        logger.info("🔍 Comparing Retriever Models")
        logger.info(f"Fine-tuned: {finetuned_path}")
        logger.info(f"Baseline: {baseline_path}")

        # Test both models
        ft_results = self._benchmark_retriever(finetuned_path, data_path, batch_size, max_samples, k_values)
        baseline_results = self._benchmark_retriever(baseline_path, data_path, batch_size, max_samples, k_values)

        # Compute improvements
        comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])

        return {
            'model_type': 'retriever',
            'finetuned': ft_results,
            'baseline': baseline_results,
            'comparison': comparison,
            'summary': self._generate_retriever_summary(comparison)
        }
    def compare_rerankers(
        self,
        finetuned_path: str,
        baseline_path: str,
        data_path: str,
        batch_size: int = 32,
        max_samples: Optional[int] = None,
        k_values: list = [1, 5, 10]
    ) -> Dict[str, Any]:
        """Compare reranker models."""
        logger.info("🏆 Comparing Reranker Models")
        logger.info(f"Fine-tuned: {finetuned_path}")
        logger.info(f"Baseline: {baseline_path}")

        # Test both models
        ft_results = self._benchmark_reranker(finetuned_path, data_path, batch_size, max_samples, k_values)
        baseline_results = self._benchmark_reranker(baseline_path, data_path, batch_size, max_samples, k_values)

        # Compute improvements
        comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])

        return {
            'model_type': 'reranker',
            'finetuned': ft_results,
            'baseline': baseline_results,
            'comparison': comparison,
            'summary': self._generate_reranker_summary(comparison)
        }

    def _benchmark_retriever(
        self,
        model_path: str,
        data_path: str,
        batch_size: int,
        max_samples: Optional[int],
        k_values: list
    ) -> Dict[str, Any]:
        """Benchmark a retriever model."""
        logger.info(f"Benchmarking retriever: {model_path}")

        # Load model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGEM3Model.from_pretrained(model_path)
        model.to(self.device)
        model.eval()

        # Load data
        dataset = BGEM3Dataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=False
        )
        if max_samples and len(dataset) > max_samples:
            dataset.data = dataset.data[:max_samples]

        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        logger.info(f"Testing on {len(dataset)} samples")

        # Benchmark
        start_memory = self.measure_memory()
        start_time = time.time()

        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                # Move to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # Get embeddings
                query_outputs = model.forward(
                    input_ids=batch['query_input_ids'],
                    attention_mask=batch['query_attention_mask'],
                    return_dense=True,
                    return_dict=True
                )
                passage_outputs = model.forward(
                    input_ids=batch['passage_input_ids'],
                    attention_mask=batch['passage_attention_mask'],
                    return_dense=True,
                    return_dict=True
                )

                all_query_embeddings.append(query_outputs['dense'].cpu().numpy())
                all_passage_embeddings.append(passage_outputs['dense'].cpu().numpy())
                all_labels.append(batch['labels'].cpu().numpy())

        end_time = time.time()
        peak_memory = self.measure_memory()

        # Compute metrics
        query_embeds = np.concatenate(all_query_embeddings, axis=0)
        passage_embeds = np.concatenate(all_passage_embeddings, axis=0)
        labels = np.concatenate(all_labels, axis=0)

        metrics = compute_retrieval_metrics(
            query_embeds, passage_embeds, labels, k_values=k_values
        )

        # Performance stats
        total_time = end_time - start_time
        throughput = len(dataset) / total_time
        latency = (total_time * 1000) / len(dataset)  # ms per sample
        memory_usage = peak_memory - start_memory

        return {
            'model_path': model_path,
            'metrics': metrics,
            'performance': {
                'throughput_samples_per_sec': throughput,
                'latency_ms_per_sample': latency,
                'memory_usage_mb': memory_usage,
                'total_time_sec': total_time,
                'samples_processed': len(dataset)
            }
        }

    def _benchmark_reranker(
        self,
        model_path: str,
        data_path: str,
        batch_size: int,
        max_samples: Optional[int],
        k_values: list
    ) -> Dict[str, Any]:
        """Benchmark a reranker model."""
        logger.info(f"Benchmarking reranker: {model_path}")

        # Load model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGERerankerModel.from_pretrained(model_path)
        model.to(self.device)
        model.eval()

        # Load data
        dataset = BGERerankerDataset(
            data_path=data_path,
            tokenizer=tokenizer,
            is_train=False
        )
        if max_samples and len(dataset) > max_samples:
            dataset.data = dataset.data[:max_samples]

        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        logger.info(f"Testing on {len(dataset)} samples")

        # Benchmark
        start_memory = self.measure_memory()
        start_time = time.time()

        all_scores = []
        all_labels = []

        with torch.no_grad():
            for batch in dataloader:
                # Move to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # Get scores
                outputs = model.forward(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    return_dict=True
                )

                logits = outputs.logits if hasattr(outputs, 'logits') else outputs.scores
                # squeeze(-1) keeps the scores 1-D even when the last batch has a single sample
                scores = torch.sigmoid(logits.squeeze(-1)).cpu().numpy()
                all_scores.append(scores)
                all_labels.append(batch['labels'].cpu().numpy())

        end_time = time.time()
        peak_memory = self.measure_memory()

        # Compute metrics
        scores = np.concatenate(all_scores, axis=0)
        labels = np.concatenate(all_labels, axis=0)

        metrics = compute_reranker_metrics(scores, labels, k_values=k_values)

        # Performance stats
        total_time = end_time - start_time
        throughput = len(dataset) / total_time
        latency = (total_time * 1000) / len(dataset)  # ms per sample
        memory_usage = peak_memory - start_memory

        return {
            'model_path': model_path,
            'metrics': metrics,
            'performance': {
                'throughput_samples_per_sec': throughput,
                'latency_ms_per_sample': latency,
                'memory_usage_mb': memory_usage,
                'total_time_sec': total_time,
                'samples_processed': len(dataset)
            }
        }

    def _compute_metric_improvements(self, baseline_metrics: Dict, finetuned_metrics: Dict) -> Dict[str, float]:
        """Compute percentage improvements for each metric."""
        improvements = {}
        for metric_name in baseline_metrics:
            baseline_val = baseline_metrics[metric_name]
            finetuned_val = finetuned_metrics.get(metric_name, baseline_val)
            if baseline_val == 0:
                improvement = 0.0 if finetuned_val == 0 else float('inf')
            else:
                improvement = ((finetuned_val - baseline_val) / baseline_val) * 100
            improvements[metric_name] = improvement
        return improvements

    def _generate_retriever_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
        """Generate summary for retriever comparison."""
        key_metrics = ['recall@5', 'recall@10', 'map', 'mrr']
        improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]
        avg_improvement = np.mean(improvements) if improvements else 0
        positive_improvements = sum(1 for imp in improvements if imp > 0)

        if avg_improvement > 5:
            status = "🌟 EXCELLENT"
        elif avg_improvement > 2:
            status = "✅ GOOD"
        elif avg_improvement > 0:
            status = "👌 FAIR"
        else:
            status = "❌ POOR"

        return {
            'status': status,
            'avg_improvement_pct': avg_improvement,
            'positive_improvements': positive_improvements,
            'total_metrics': len(improvements),
            'key_improvements': {metric: comparison.get(metric, 0)
                                 for metric in key_metrics if metric in comparison}
        }

    def _generate_reranker_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
        """Generate summary for reranker comparison."""
        key_metrics = ['accuracy', 'precision', 'recall', 'f1', 'auc']
        improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]
        avg_improvement = np.mean(improvements) if improvements else 0
        positive_improvements = sum(1 for imp in improvements if imp > 0)

        if avg_improvement > 3:
            status = "🌟 EXCELLENT"
        elif avg_improvement > 1:
            status = "✅ GOOD"
        elif avg_improvement > 0:
            status = "👌 FAIR"
        else:
            status = "❌ POOR"

        return {
            'status': status,
            'avg_improvement_pct': avg_improvement,
            'positive_improvements': positive_improvements,
            'total_metrics': len(improvements),
            'key_improvements': {metric: comparison.get(metric, 0)
                                 for metric in key_metrics if metric in comparison}
        }

    def print_comparison_results(self, results: Dict[str, Any], output_file: Optional[str] = None):
        """Print formatted comparison results."""
        model_type = results['model_type'].upper()
        summary = results['summary']

        output = []
        output.append(f"\n{'='*80}")
        output.append(f"📊 {model_type} MODEL COMPARISON RESULTS")
        output.append(f"{'='*80}")

        output.append(f"\n🎯 OVERALL ASSESSMENT: {summary['status']}")
        output.append(f"📈 Average Improvement: {summary['avg_improvement_pct']:.2f}%")
        output.append(f"✅ Metrics Improved: {summary['positive_improvements']}/{summary['total_metrics']}")

        # Key metrics comparison
        output.append("\n📋 KEY METRICS COMPARISON:")
        output.append(f"{'Metric':<20} {'Baseline':<12} {'Fine-tuned':<12} {'Improvement':<12}")
        output.append("-" * 60)

        for metric, improvement in summary['key_improvements'].items():
            baseline_val = results['baseline']['metrics'].get(metric, 0)
            finetuned_val = results['finetuned']['metrics'].get(metric, 0)
            status_icon = "📈" if improvement > 0 else "📉" if improvement < 0 else "➡️"
            output.append(f"{metric:<20} {baseline_val:<12.4f} {finetuned_val:<12.4f} {status_icon} {improvement:>+7.2f}%")

        # Performance comparison
        output.append("\n⚡ PERFORMANCE COMPARISON:")
        baseline_perf = results['baseline']['performance']
        finetuned_perf = results['finetuned']['performance']

        output.append(f"{'Metric':<25} {'Baseline':<15} {'Fine-tuned':<15} {'Change':<15}")
        output.append("-" * 75)

        perf_metrics = [
            ('Throughput (samples/s)', 'throughput_samples_per_sec'),
            ('Latency (ms/sample)', 'latency_ms_per_sample'),
            ('Memory Usage (MB)', 'memory_usage_mb')
        ]

        for display_name, key in perf_metrics:
            baseline_val = baseline_perf.get(key, 0)
            finetuned_val = finetuned_perf.get(key, 0)
            if baseline_val == 0:
                change = 0
            else:
                change = ((finetuned_val - baseline_val) / baseline_val) * 100
            change_icon = "⬆️" if change > 0 else "⬇️" if change < 0 else "➡️"
            output.append(f"{display_name:<25} {baseline_val:<15.2f} {finetuned_val:<15.2f} {change_icon} {change:>+7.2f}%")

        # Recommendations
        output.append("\n💡 RECOMMENDATIONS:")
        if summary['avg_improvement_pct'] > 5:
            output.append(" • Excellent improvement! Model is ready for production.")
        elif summary['avg_improvement_pct'] > 2:
            output.append(" • Good improvement. Consider additional training or data.")
        elif summary['avg_improvement_pct'] > 0:
            output.append(" • Modest improvement. Review training hyperparameters.")
        else:
            output.append(" • Poor improvement. Review data quality and training strategy.")

        if summary['positive_improvements'] < summary['total_metrics'] // 2:
            output.append(" • Some metrics degraded. Consider ensemble or different approach.")

        output.append(f"\n{'='*80}")

        # Print to console
        for line in output:
            print(line)

        # Save to file if specified
        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write('\n'.join(output))
            logger.info(f"Results saved to: {output_file}")


def parse_args():
    parser = argparse.ArgumentParser(description="Compare BGE models against baselines")

    # Model type
    parser.add_argument("--model_type", type=str, choices=['retriever', 'reranker', 'both'],
                        required=True, help="Type of model to compare")

    # Single model comparison arguments
    parser.add_argument("--finetuned_model", type=str, help="Path to fine-tuned model")
    parser.add_argument("--baseline_model", type=str, help="Path to baseline model")
    parser.add_argument("--data_path", type=str, help="Path to test data")

    # Both models comparison arguments
    parser.add_argument("--finetuned_retriever", type=str, help="Path to fine-tuned retriever")
    parser.add_argument("--finetuned_reranker", type=str, help="Path to fine-tuned reranker")
    parser.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3",
                        help="Baseline retriever model")
    parser.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base",
                        help="Baseline reranker model")
    parser.add_argument("--retriever_data", type=str, help="Path to retriever test data")
    parser.add_argument("--reranker_data", type=str, help="Path to reranker test data")

    # Common arguments
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation")
    parser.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)")
    parser.add_argument("--k_values", type=int, nargs='+', default=[1, 5, 10, 20, 100],
                        help="K values for metrics@k")
    parser.add_argument("--device", type=str, default='auto', help="Device to use")
    parser.add_argument("--output_dir", type=str, default="./comparison_results",
                        help="Directory to save results")

    return parser.parse_args()


def main():
    args = parse_args()

    # Create output directory
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Initialize comparator
    comparator = ModelComparator(device=args.device)

    logger.info("🚀 Starting BGE Model Comparison")
    logger.info(f"Model type: {args.model_type}")

    if args.model_type == 'retriever':
        if not all([args.finetuned_model, args.baseline_model, args.data_path]):
            logger.error("For retriever comparison, provide --finetuned_model, --baseline_model, and --data_path")
            return 1

        results = comparator.compare_retrievers(
            args.finetuned_model, args.baseline_model, args.data_path,
            batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
        )

        output_file = Path(args.output_dir) / "retriever_comparison.txt"
        comparator.print_comparison_results(results, str(output_file))

        # Save JSON results
        with open(Path(args.output_dir) / "retriever_comparison.json", 'w') as f:
            json.dump(results, f, indent=2, default=str)

    elif args.model_type == 'reranker':
        if not all([args.finetuned_model, args.baseline_model, args.data_path]):
            logger.error("For reranker comparison, provide --finetuned_model, --baseline_model, and --data_path")
            return 1

        results = comparator.compare_rerankers(
            args.finetuned_model, args.baseline_model, args.data_path,
            batch_size=args.batch_size, max_samples=args.max_samples,
            k_values=[k for k in args.k_values if k <= 10]  # Reranker typically uses smaller K
        )

        output_file = Path(args.output_dir) / "reranker_comparison.txt"
        comparator.print_comparison_results(results, str(output_file))

        # Save JSON results
        with open(Path(args.output_dir) / "reranker_comparison.json", 'w') as f:
            json.dump(results, f, indent=2, default=str)

    elif args.model_type == 'both':
        if not all([args.finetuned_retriever, args.finetuned_reranker,
                    args.retriever_data, args.reranker_data]):
            logger.error("For both comparison, provide all fine-tuned models and data paths")
            return 1

        # Compare retriever
        logger.info("\n" + "=" * 50)
        logger.info("RETRIEVER COMPARISON")
        logger.info("=" * 50)
        retriever_results = comparator.compare_retrievers(
            args.finetuned_retriever, args.baseline_retriever, args.retriever_data,
            batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
        )

        # Compare reranker
        logger.info("\n" + "=" * 50)
        logger.info("RERANKER COMPARISON")
        logger.info("=" * 50)
        reranker_results = comparator.compare_rerankers(
            args.finetuned_reranker, args.baseline_reranker, args.reranker_data,
            batch_size=args.batch_size, max_samples=args.max_samples,
            k_values=[k for k in args.k_values if k <= 10]
        )

        # Print results
        comparator.print_comparison_results(
            retriever_results, str(Path(args.output_dir) / "retriever_comparison.txt")
        )
        comparator.print_comparison_results(
            reranker_results, str(Path(args.output_dir) / "reranker_comparison.txt")
        )

        # Save combined results
        combined_results = {
            'retriever': retriever_results,
            'reranker': reranker_results,
            'overall_summary': {
                'retriever_status': retriever_results['summary']['status'],
                'reranker_status': reranker_results['summary']['status'],
                'avg_retriever_improvement': retriever_results['summary']['avg_improvement_pct'],
                'avg_reranker_improvement': reranker_results['summary']['avg_improvement_pct']
            }
        }
        with open(Path(args.output_dir) / "combined_comparison.json", 'w') as f:
            json.dump(combined_results, f, indent=2, default=str)

    logger.info(f"✅ Comparison completed! Results saved to: {args.output_dir}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value (1 on argument errors) as the process exit code
    sys.exit(main())