Revised for training
This commit is contained in:
611
scripts/compare_models.py
Normal file
611
scripts/compare_models.py
Normal file
@@ -0,0 +1,611 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unified BGE Model Comparison Script
|
||||
|
||||
This script compares fine-tuned BGE models (retriever/reranker) against baseline models,
|
||||
providing both accuracy metrics and performance comparisons.
|
||||
|
||||
Usage:
|
||||
# Compare retriever models
|
||||
python scripts/compare_models.py \
|
||||
--model_type retriever \
|
||||
--finetuned_model ./output/bge_m3_三国演义/final_model \
|
||||
--baseline_model BAAI/bge-m3 \
|
||||
--data_path data/datasets/三国演义/splits/m3_test.jsonl
|
||||
|
||||
# Compare reranker models
|
||||
python scripts/compare_models.py \
|
||||
--model_type reranker \
|
||||
--finetuned_model ./output/bge_reranker_三国演义/final_model \
|
||||
--baseline_model BAAI/bge-reranker-base \
|
||||
--data_path data/datasets/三国演义/splits/reranker_test.jsonl
|
||||
|
||||
# Compare both with comprehensive output
|
||||
python scripts/compare_models.py \
|
||||
--model_type both \
|
||||
--finetuned_retriever ./output/bge_m3_三国演义/final_model \
|
||||
--finetuned_reranker ./output/bge_reranker_三国演义/final_model \
|
||||
--baseline_retriever BAAI/bge-m3 \
|
||||
--baseline_reranker BAAI/bge-reranker-base \
|
||||
--retriever_data data/datasets/三国演义/splits/m3_test.jsonl \
|
||||
--reranker_data data/datasets/三国演义/splits/reranker_test.jsonl
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Optional, Any
|
||||
import torch
|
||||
import psutil
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGEM3Dataset, BGERerankerDataset
|
||||
from evaluation.metrics import compute_retrieval_metrics, compute_reranker_metrics
|
||||
from utils.config_loader import get_config
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelComparator:
|
||||
"""Unified model comparison class for both retriever and reranker models"""
|
||||
|
||||
def __init__(self, device: str = 'auto'):
|
||||
self.device = self._get_device(device)
|
||||
|
||||
def _get_device(self, device: str) -> torch.device:
|
||||
"""Get optimal device"""
|
||||
if device == 'auto':
|
||||
try:
|
||||
config = get_config()
|
||||
device_type = config.get('hardware', {}).get('device', 'cuda')
|
||||
|
||||
if device_type == 'npu' and torch.cuda.is_available():
|
||||
try:
|
||||
import torch_npu
|
||||
return torch.device("npu:0")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
return torch.device(device)
|
||||
|
||||
def measure_memory(self) -> float:
|
||||
"""Measure current memory usage in MB"""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / (1024 * 1024)
|
||||
|
||||
def compare_retrievers(
|
||||
self,
|
||||
finetuned_path: str,
|
||||
baseline_path: str,
|
||||
data_path: str,
|
||||
batch_size: int = 16,
|
||||
max_samples: Optional[int] = None,
|
||||
k_values: list = [1, 5, 10, 20, 100]
|
||||
) -> Dict[str, Any]:
|
||||
"""Compare retriever models"""
|
||||
logger.info("🔍 Comparing Retriever Models")
|
||||
logger.info(f"Fine-tuned: {finetuned_path}")
|
||||
logger.info(f"Baseline: {baseline_path}")
|
||||
|
||||
# Test both models
|
||||
ft_results = self._benchmark_retriever(finetuned_path, data_path, batch_size, max_samples, k_values)
|
||||
baseline_results = self._benchmark_retriever(baseline_path, data_path, batch_size, max_samples, k_values)
|
||||
|
||||
# Compute improvements
|
||||
comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])
|
||||
|
||||
return {
|
||||
'model_type': 'retriever',
|
||||
'finetuned': ft_results,
|
||||
'baseline': baseline_results,
|
||||
'comparison': comparison,
|
||||
'summary': self._generate_retriever_summary(comparison)
|
||||
}
|
||||
|
||||
def compare_rerankers(
|
||||
self,
|
||||
finetuned_path: str,
|
||||
baseline_path: str,
|
||||
data_path: str,
|
||||
batch_size: int = 32,
|
||||
max_samples: Optional[int] = None,
|
||||
k_values: list = [1, 5, 10]
|
||||
) -> Dict[str, Any]:
|
||||
"""Compare reranker models"""
|
||||
logger.info("🏆 Comparing Reranker Models")
|
||||
logger.info(f"Fine-tuned: {finetuned_path}")
|
||||
logger.info(f"Baseline: {baseline_path}")
|
||||
|
||||
# Test both models
|
||||
ft_results = self._benchmark_reranker(finetuned_path, data_path, batch_size, max_samples, k_values)
|
||||
baseline_results = self._benchmark_reranker(baseline_path, data_path, batch_size, max_samples, k_values)
|
||||
|
||||
# Compute improvements
|
||||
comparison = self._compute_metric_improvements(baseline_results['metrics'], ft_results['metrics'])
|
||||
|
||||
return {
|
||||
'model_type': 'reranker',
|
||||
'finetuned': ft_results,
|
||||
'baseline': baseline_results,
|
||||
'comparison': comparison,
|
||||
'summary': self._generate_reranker_summary(comparison)
|
||||
}
|
||||
|
||||
def _benchmark_retriever(
|
||||
self,
|
||||
model_path: str,
|
||||
data_path: str,
|
||||
batch_size: int,
|
||||
max_samples: Optional[int],
|
||||
k_values: list
|
||||
) -> Dict[str, Any]:
|
||||
"""Benchmark a retriever model"""
|
||||
logger.info(f"Benchmarking retriever: {model_path}")
|
||||
|
||||
# Load model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGEM3Model.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# Load data
|
||||
dataset = BGEM3Dataset(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
if max_samples and len(dataset) > max_samples:
|
||||
dataset.data = dataset.data[:max_samples]
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
||||
logger.info(f"Testing on {len(dataset)} samples")
|
||||
|
||||
# Benchmark
|
||||
start_memory = self.measure_memory()
|
||||
start_time = time.time()
|
||||
|
||||
all_query_embeddings = []
|
||||
all_passage_embeddings = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
# Move to device
|
||||
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in batch.items()}
|
||||
|
||||
# Get embeddings
|
||||
query_outputs = model.forward(
|
||||
input_ids=batch['query_input_ids'],
|
||||
attention_mask=batch['query_attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
passage_outputs = model.forward(
|
||||
input_ids=batch['passage_input_ids'],
|
||||
attention_mask=batch['passage_attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
all_query_embeddings.append(query_outputs['dense'].cpu().numpy())
|
||||
all_passage_embeddings.append(passage_outputs['dense'].cpu().numpy())
|
||||
all_labels.append(batch['labels'].cpu().numpy())
|
||||
|
||||
end_time = time.time()
|
||||
peak_memory = self.measure_memory()
|
||||
|
||||
# Compute metrics
|
||||
query_embeds = np.concatenate(all_query_embeddings, axis=0)
|
||||
passage_embeds = np.concatenate(all_passage_embeddings, axis=0)
|
||||
labels = np.concatenate(all_labels, axis=0)
|
||||
|
||||
metrics = compute_retrieval_metrics(
|
||||
query_embeds, passage_embeds, labels, k_values=k_values
|
||||
)
|
||||
|
||||
# Performance stats
|
||||
total_time = end_time - start_time
|
||||
throughput = len(dataset) / total_time
|
||||
latency = (total_time * 1000) / len(dataset) # ms per sample
|
||||
memory_usage = peak_memory - start_memory
|
||||
|
||||
return {
|
||||
'model_path': model_path,
|
||||
'metrics': metrics,
|
||||
'performance': {
|
||||
'throughput_samples_per_sec': throughput,
|
||||
'latency_ms_per_sample': latency,
|
||||
'memory_usage_mb': memory_usage,
|
||||
'total_time_sec': total_time,
|
||||
'samples_processed': len(dataset)
|
||||
}
|
||||
}
|
||||
|
||||
def _benchmark_reranker(
|
||||
self,
|
||||
model_path: str,
|
||||
data_path: str,
|
||||
batch_size: int,
|
||||
max_samples: Optional[int],
|
||||
k_values: list
|
||||
) -> Dict[str, Any]:
|
||||
"""Benchmark a reranker model"""
|
||||
logger.info(f"Benchmarking reranker: {model_path}")
|
||||
|
||||
# Load model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.to(self.device)
|
||||
model.eval()
|
||||
|
||||
# Load data
|
||||
dataset = BGERerankerDataset(
|
||||
data_path=data_path,
|
||||
tokenizer=tokenizer,
|
||||
is_train=False
|
||||
)
|
||||
|
||||
if max_samples and len(dataset) > max_samples:
|
||||
dataset.data = dataset.data[:max_samples]
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
||||
logger.info(f"Testing on {len(dataset)} samples")
|
||||
|
||||
# Benchmark
|
||||
start_memory = self.measure_memory()
|
||||
start_time = time.time()
|
||||
|
||||
all_scores = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
# Move to device
|
||||
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in batch.items()}
|
||||
|
||||
# Get scores
|
||||
outputs = model.forward(
|
||||
input_ids=batch['input_ids'],
|
||||
attention_mask=batch['attention_mask'],
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
logits = outputs.logits if hasattr(outputs, 'logits') else outputs.scores
|
||||
scores = torch.sigmoid(logits.squeeze()).cpu().numpy()
|
||||
|
||||
all_scores.append(scores)
|
||||
all_labels.append(batch['labels'].cpu().numpy())
|
||||
|
||||
end_time = time.time()
|
||||
peak_memory = self.measure_memory()
|
||||
|
||||
# Compute metrics
|
||||
scores = np.concatenate(all_scores, axis=0)
|
||||
labels = np.concatenate(all_labels, axis=0)
|
||||
|
||||
metrics = compute_reranker_metrics(scores, labels, k_values=k_values)
|
||||
|
||||
# Performance stats
|
||||
total_time = end_time - start_time
|
||||
throughput = len(dataset) / total_time
|
||||
latency = (total_time * 1000) / len(dataset) # ms per sample
|
||||
memory_usage = peak_memory - start_memory
|
||||
|
||||
return {
|
||||
'model_path': model_path,
|
||||
'metrics': metrics,
|
||||
'performance': {
|
||||
'throughput_samples_per_sec': throughput,
|
||||
'latency_ms_per_sample': latency,
|
||||
'memory_usage_mb': memory_usage,
|
||||
'total_time_sec': total_time,
|
||||
'samples_processed': len(dataset)
|
||||
}
|
||||
}
|
||||
|
||||
def _compute_metric_improvements(self, baseline_metrics: Dict, finetuned_metrics: Dict) -> Dict[str, float]:
|
||||
"""Compute percentage improvements for each metric"""
|
||||
improvements = {}
|
||||
|
||||
for metric_name in baseline_metrics:
|
||||
baseline_val = baseline_metrics[metric_name]
|
||||
finetuned_val = finetuned_metrics.get(metric_name, baseline_val)
|
||||
|
||||
if baseline_val == 0:
|
||||
improvement = 0.0 if finetuned_val == 0 else float('inf')
|
||||
else:
|
||||
improvement = ((finetuned_val - baseline_val) / baseline_val) * 100
|
||||
|
||||
improvements[metric_name] = improvement
|
||||
|
||||
return improvements
|
||||
|
||||
def _generate_retriever_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
|
||||
"""Generate summary for retriever comparison"""
|
||||
key_metrics = ['recall@5', 'recall@10', 'map', 'mrr']
|
||||
improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]
|
||||
|
||||
avg_improvement = np.mean(improvements) if improvements else 0
|
||||
positive_improvements = sum(1 for imp in improvements if imp > 0)
|
||||
|
||||
if avg_improvement > 5:
|
||||
status = "🌟 EXCELLENT"
|
||||
elif avg_improvement > 2:
|
||||
status = "✅ GOOD"
|
||||
elif avg_improvement > 0:
|
||||
status = "👌 FAIR"
|
||||
else:
|
||||
status = "❌ POOR"
|
||||
|
||||
return {
|
||||
'status': status,
|
||||
'avg_improvement_pct': avg_improvement,
|
||||
'positive_improvements': positive_improvements,
|
||||
'total_metrics': len(improvements),
|
||||
'key_improvements': {metric: comparison.get(metric, 0) for metric in key_metrics if metric in comparison}
|
||||
}
|
||||
|
||||
def _generate_reranker_summary(self, comparison: Dict[str, float]) -> Dict[str, Any]:
|
||||
"""Generate summary for reranker comparison"""
|
||||
key_metrics = ['accuracy', 'precision', 'recall', 'f1', 'auc']
|
||||
improvements = [comparison.get(metric, 0) for metric in key_metrics if metric in comparison]
|
||||
|
||||
avg_improvement = np.mean(improvements) if improvements else 0
|
||||
positive_improvements = sum(1 for imp in improvements if imp > 0)
|
||||
|
||||
if avg_improvement > 3:
|
||||
status = "🌟 EXCELLENT"
|
||||
elif avg_improvement > 1:
|
||||
status = "✅ GOOD"
|
||||
elif avg_improvement > 0:
|
||||
status = "👌 FAIR"
|
||||
else:
|
||||
status = "❌ POOR"
|
||||
|
||||
return {
|
||||
'status': status,
|
||||
'avg_improvement_pct': avg_improvement,
|
||||
'positive_improvements': positive_improvements,
|
||||
'total_metrics': len(improvements),
|
||||
'key_improvements': {metric: comparison.get(metric, 0) for metric in key_metrics if metric in comparison}
|
||||
}
|
||||
|
||||
def print_comparison_results(self, results: Dict[str, Any], output_file: Optional[str] = None):
|
||||
"""Print formatted comparison results"""
|
||||
model_type = results['model_type'].upper()
|
||||
summary = results['summary']
|
||||
|
||||
output = []
|
||||
output.append(f"\n{'='*80}")
|
||||
output.append(f"📊 {model_type} MODEL COMPARISON RESULTS")
|
||||
output.append(f"{'='*80}")
|
||||
|
||||
output.append(f"\n🎯 OVERALL ASSESSMENT: {summary['status']}")
|
||||
output.append(f"📈 Average Improvement: {summary['avg_improvement_pct']:.2f}%")
|
||||
output.append(f"✅ Metrics Improved: {summary['positive_improvements']}/{summary['total_metrics']}")
|
||||
|
||||
# Key metrics comparison
|
||||
output.append(f"\n📋 KEY METRICS COMPARISON:")
|
||||
output.append(f"{'Metric':<20} {'Baseline':<12} {'Fine-tuned':<12} {'Improvement':<12}")
|
||||
output.append("-" * 60)
|
||||
|
||||
for metric, improvement in summary['key_improvements'].items():
|
||||
baseline_val = results['baseline']['metrics'].get(metric, 0)
|
||||
finetuned_val = results['finetuned']['metrics'].get(metric, 0)
|
||||
|
||||
status_icon = "📈" if improvement > 0 else "📉" if improvement < 0 else "➡️"
|
||||
output.append(f"{metric:<20} {baseline_val:<12.4f} {finetuned_val:<12.4f} {status_icon} {improvement:>+7.2f}%")
|
||||
|
||||
# Performance comparison
|
||||
output.append(f"\n⚡ PERFORMANCE COMPARISON:")
|
||||
baseline_perf = results['baseline']['performance']
|
||||
finetuned_perf = results['finetuned']['performance']
|
||||
|
||||
output.append(f"{'Metric':<25} {'Baseline':<15} {'Fine-tuned':<15} {'Change':<15}")
|
||||
output.append("-" * 75)
|
||||
|
||||
perf_metrics = [
|
||||
('Throughput (samples/s)', 'throughput_samples_per_sec'),
|
||||
('Latency (ms/sample)', 'latency_ms_per_sample'),
|
||||
('Memory Usage (MB)', 'memory_usage_mb')
|
||||
]
|
||||
|
||||
for display_name, key in perf_metrics:
|
||||
baseline_val = baseline_perf.get(key, 0)
|
||||
finetuned_val = finetuned_perf.get(key, 0)
|
||||
|
||||
if baseline_val == 0:
|
||||
change = 0
|
||||
else:
|
||||
change = ((finetuned_val - baseline_val) / baseline_val) * 100
|
||||
|
||||
change_icon = "⬆️" if change > 0 else "⬇️" if change < 0 else "➡️"
|
||||
output.append(f"{display_name:<25} {baseline_val:<15.2f} {finetuned_val:<15.2f} {change_icon} {change:>+7.2f}%")
|
||||
|
||||
# Recommendations
|
||||
output.append(f"\n💡 RECOMMENDATIONS:")
|
||||
if summary['avg_improvement_pct'] > 5:
|
||||
output.append(" • Excellent improvement! Model is ready for production.")
|
||||
elif summary['avg_improvement_pct'] > 2:
|
||||
output.append(" • Good improvement. Consider additional training or data.")
|
||||
elif summary['avg_improvement_pct'] > 0:
|
||||
output.append(" • Modest improvement. Review training hyperparameters.")
|
||||
else:
|
||||
output.append(" • Poor improvement. Review data quality and training strategy.")
|
||||
|
||||
if summary['positive_improvements'] < summary['total_metrics'] // 2:
|
||||
output.append(" • Some metrics degraded. Consider ensemble or different approach.")
|
||||
|
||||
output.append(f"\n{'='*80}")
|
||||
|
||||
# Print to console
|
||||
for line in output:
|
||||
print(line)
|
||||
|
||||
# Save to file if specified
|
||||
if output_file:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(output))
|
||||
logger.info(f"Results saved to: {output_file}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Compare BGE models against baselines")
|
||||
|
||||
# Model type
|
||||
parser.add_argument("--model_type", type=str, choices=['retriever', 'reranker', 'both'],
|
||||
required=True, help="Type of model to compare")
|
||||
|
||||
# Single model comparison arguments
|
||||
parser.add_argument("--finetuned_model", type=str, help="Path to fine-tuned model")
|
||||
parser.add_argument("--baseline_model", type=str, help="Path to baseline model")
|
||||
parser.add_argument("--data_path", type=str, help="Path to test data")
|
||||
|
||||
# Both models comparison arguments
|
||||
parser.add_argument("--finetuned_retriever", type=str, help="Path to fine-tuned retriever")
|
||||
parser.add_argument("--finetuned_reranker", type=str, help="Path to fine-tuned reranker")
|
||||
parser.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3",
|
||||
help="Baseline retriever model")
|
||||
parser.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Baseline reranker model")
|
||||
parser.add_argument("--retriever_data", type=str, help="Path to retriever test data")
|
||||
parser.add_argument("--reranker_data", type=str, help="Path to reranker test data")
|
||||
|
||||
# Common arguments
|
||||
parser.add_argument("--batch_size", type=int, default=16, help="Batch size for evaluation")
|
||||
parser.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)")
|
||||
parser.add_argument("--k_values", type=int, nargs='+', default=[1, 5, 10, 20, 100],
|
||||
help="K values for metrics@k")
|
||||
parser.add_argument("--device", type=str, default='auto', help="Device to use")
|
||||
parser.add_argument("--output_dir", type=str, default="./comparison_results",
|
||||
help="Directory to save results")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args() # Changed from parser.parse_args() to parse_args()
|
||||
|
||||
# Create output directory
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize comparator
|
||||
comparator = ModelComparator(device=args.device)
|
||||
|
||||
logger.info("🚀 Starting BGE Model Comparison")
|
||||
logger.info(f"Model type: {args.model_type}")
|
||||
|
||||
if args.model_type == 'retriever':
|
||||
if not all([args.finetuned_model, args.baseline_model, args.data_path]):
|
||||
logger.error("For retriever comparison, provide --finetuned_model, --baseline_model, and --data_path")
|
||||
return 1
|
||||
|
||||
results = comparator.compare_retrievers(
|
||||
args.finetuned_model, args.baseline_model, args.data_path,
|
||||
batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
|
||||
)
|
||||
|
||||
output_file = Path(args.output_dir) / "retriever_comparison.txt"
|
||||
comparator.print_comparison_results(results, str(output_file))
|
||||
|
||||
# Save JSON results
|
||||
with open(Path(args.output_dir) / "retriever_comparison.json", 'w') as f:
|
||||
json.dump(results, f, indent=2, default=str)
|
||||
|
||||
elif args.model_type == 'reranker':
|
||||
if not all([args.finetuned_model, args.baseline_model, args.data_path]):
|
||||
logger.error("For reranker comparison, provide --finetuned_model, --baseline_model, and --data_path")
|
||||
return 1
|
||||
|
||||
results = comparator.compare_rerankers(
|
||||
args.finetuned_model, args.baseline_model, args.data_path,
|
||||
batch_size=args.batch_size, max_samples=args.max_samples,
|
||||
k_values=[k for k in args.k_values if k <= 10] # Reranker typically uses smaller K
|
||||
)
|
||||
|
||||
output_file = Path(args.output_dir) / "reranker_comparison.txt"
|
||||
comparator.print_comparison_results(results, str(output_file))
|
||||
|
||||
# Save JSON results
|
||||
with open(Path(args.output_dir) / "reranker_comparison.json", 'w') as f:
|
||||
json.dump(results, f, indent=2, default=str)
|
||||
|
||||
elif args.model_type == 'both':
|
||||
if not all([args.finetuned_retriever, args.finetuned_reranker,
|
||||
args.retriever_data, args.reranker_data]):
|
||||
logger.error("For both comparison, provide all fine-tuned models and data paths")
|
||||
return 1
|
||||
|
||||
# Compare retriever
|
||||
logger.info("\n" + "="*50)
|
||||
logger.info("RETRIEVER COMPARISON")
|
||||
logger.info("="*50)
|
||||
|
||||
retriever_results = comparator.compare_retrievers(
|
||||
args.finetuned_retriever, args.baseline_retriever, args.retriever_data,
|
||||
batch_size=args.batch_size, max_samples=args.max_samples, k_values=args.k_values
|
||||
)
|
||||
|
||||
# Compare reranker
|
||||
logger.info("\n" + "="*50)
|
||||
logger.info("RERANKER COMPARISON")
|
||||
logger.info("="*50)
|
||||
|
||||
reranker_results = comparator.compare_rerankers(
|
||||
args.finetuned_reranker, args.baseline_reranker, args.reranker_data,
|
||||
batch_size=args.batch_size, max_samples=args.max_samples,
|
||||
k_values=[k for k in args.k_values if k <= 10]
|
||||
)
|
||||
|
||||
# Print results
|
||||
comparator.print_comparison_results(
|
||||
retriever_results, str(Path(args.output_dir) / "retriever_comparison.txt")
|
||||
)
|
||||
comparator.print_comparison_results(
|
||||
reranker_results, str(Path(args.output_dir) / "reranker_comparison.txt")
|
||||
)
|
||||
|
||||
# Save combined results
|
||||
combined_results = {
|
||||
'retriever': retriever_results,
|
||||
'reranker': reranker_results,
|
||||
'overall_summary': {
|
||||
'retriever_status': retriever_results['summary']['status'],
|
||||
'reranker_status': reranker_results['summary']['status'],
|
||||
'avg_retriever_improvement': retriever_results['summary']['avg_improvement_pct'],
|
||||
'avg_reranker_improvement': reranker_results['summary']['avg_improvement_pct']
|
||||
}
|
||||
}
|
||||
|
||||
with open(Path(args.output_dir) / "combined_comparison.json", 'w') as f:
|
||||
json.dump(combined_results, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"✅ Comparison completed! Results saved to: {args.output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,138 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import psutil
|
||||
import tracemalloc
|
||||
import numpy as np
|
||||
from utils.config_loader import get_config
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from data.dataset import BGERerankerDataset
|
||||
from transformers import AutoTokenizer
|
||||
from evaluation.metrics import compute_reranker_metrics
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("compare_reranker")
|
||||
|
||||
def parse_args():
|
||||
"""Parse command-line arguments for reranker comparison."""
|
||||
parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline reranker models.")
|
||||
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
|
||||
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
|
||||
parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
|
||||
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
|
||||
return parser.parse_args()
|
||||
|
||||
def measure_memory():
|
||||
"""Measure the current memory usage of the process."""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / (1024 * 1024) # MB
|
||||
|
||||
def run_reranker(model_path, data_path, batch_size, max_samples, device):
|
||||
"""
|
||||
Run inference on a reranker model and measure throughput, latency, and memory.
|
||||
Args:
|
||||
model_path (str): Path to the reranker model.
|
||||
data_path (str): Path to the evaluation data.
|
||||
batch_size (int): Batch size for inference.
|
||||
max_samples (int): Maximum number of samples to process.
|
||||
device (str): Device to use (cuda, cpu, npu).
|
||||
Returns:
|
||||
tuple: (throughput, latency, mem_mb, all_scores, all_labels)
|
||||
"""
|
||||
model = BGERerankerModel(model_name_or_path=model_path, device=device)
|
||||
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
|
||||
dataset = BGERerankerDataset(data_path, tokenizer, is_train=False)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
n_samples = 0
|
||||
times = []
|
||||
all_scores = []
|
||||
all_labels = []
|
||||
tracemalloc.start()
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
if n_samples >= max_samples:
|
||||
break
|
||||
start = time.time()
|
||||
outputs = model.model(
|
||||
input_ids=batch['input_ids'].to(device),
|
||||
attention_mask=batch['attention_mask'].to(device)
|
||||
)
|
||||
torch.cuda.synchronize() if device.startswith('cuda') else None
|
||||
end = time.time()
|
||||
times.append(end - start)
|
||||
n = batch['input_ids'].shape[0]
|
||||
n_samples += n
|
||||
# For metrics: collect scores and labels
|
||||
scores = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
|
||||
all_scores.append(scores.cpu().numpy())
|
||||
all_labels.append(batch['labels'].cpu().numpy())
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
throughput = n_samples / sum(times)
|
||||
latency = np.mean(times) / batch_size
|
||||
mem_mb = peak / (1024 * 1024)
|
||||
# Prepare for metrics
|
||||
all_scores = np.concatenate(all_scores, axis=0)
|
||||
all_labels = np.concatenate(all_labels, axis=0)
|
||||
return throughput, latency, mem_mb, all_scores, all_labels
|
||||
|
||||
def main():
|
||||
"""Run reranker comparison with robust error handling and config-driven defaults."""
|
||||
parser = argparse.ArgumentParser(description="Compare reranker models.")
|
||||
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
|
||||
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned reranker model')
|
||||
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline reranker model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--output', type=str, default='compare_reranker_results.txt', help='File to save comparison results')
|
||||
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10], help='K values for metrics@k')
|
||||
args = parser.parse_args()
|
||||
# Load config and set defaults
|
||||
config = get_config()
|
||||
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
|
||||
try:
|
||||
# Fine-tuned model
|
||||
logger.info("Running fine-tuned reranker...")
|
||||
ft_throughput, ft_latency, ft_mem, ft_scores, ft_labels = run_reranker(
|
||||
args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
ft_metrics = compute_reranker_metrics(ft_scores, ft_labels, k_values=args.k_values)
|
||||
# Baseline model
|
||||
logger.info("Running baseline reranker...")
|
||||
bl_throughput, bl_latency, bl_mem, bl_scores, bl_labels = run_reranker(
|
||||
args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
bl_metrics = compute_reranker_metrics(bl_scores, bl_labels, k_values=args.k_values)
|
||||
# Output comparison
|
||||
with open(args.output, 'w') as f:
|
||||
f.write(f"Metric\tBaseline\tFine-tuned\tDelta\n")
|
||||
for k in args.k_values:
|
||||
for metric in [f'accuracy@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
|
||||
bl = bl_metrics.get(metric, 0)
|
||||
ft = ft_metrics.get(metric, 0)
|
||||
delta = ft - bl
|
||||
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
|
||||
for metric in ['map', 'mrr']:
|
||||
bl = bl_metrics.get(metric, 0)
|
||||
ft = ft_metrics.get(metric, 0)
|
||||
delta = ft - bl
|
||||
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
|
||||
f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
|
||||
f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
|
||||
f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")
|
||||
logger.info(f"Comparison results saved to {args.output}")
|
||||
print(f"\nComparison complete. See {args.output} for details.")
|
||||
except Exception as e:
|
||||
logging.error(f"Reranker comparison failed: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,144 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import psutil
|
||||
import tracemalloc
|
||||
import numpy as np
|
||||
from utils.config_loader import get_config
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from data.dataset import BGEM3Dataset
|
||||
from transformers import AutoTokenizer
|
||||
from evaluation.metrics import compute_retrieval_metrics
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("compare_retriever")
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments for retriever comparison."""
|
||||
parser = argparse.ArgumentParser(description="Compare fine-tuned and baseline retriever models.")
|
||||
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
|
||||
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda, cpu, npu)')
|
||||
parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
|
||||
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
|
||||
return parser.parse_args()
|
||||
|
||||
def measure_memory():
|
||||
"""Measure the current process memory usage in MB."""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / (1024 * 1024) # MB
|
||||
|
||||
def run_retriever(model_path, data_path, batch_size, max_samples, device):
|
||||
"""
|
||||
Run inference on a retriever model and measure throughput, latency, and memory.
|
||||
Args:
|
||||
model_path (str): Path to the retriever model.
|
||||
data_path (str): Path to the evaluation data.
|
||||
batch_size (int): Batch size for inference.
|
||||
max_samples (int): Maximum number of samples to process.
|
||||
device (str): Device to use (cuda, cpu, npu).
|
||||
Returns:
|
||||
tuple: throughput (samples/sec), latency (ms/sample), peak memory (MB),
|
||||
query_embeds (np.ndarray), passage_embeds (np.ndarray), labels (np.ndarray).
|
||||
"""
|
||||
model = BGEM3Model(model_name_or_path=model_path, device=device)
|
||||
tokenizer = model.tokenizer if hasattr(model, 'tokenizer') and model.tokenizer else AutoTokenizer.from_pretrained(model_path)
|
||||
dataset = BGEM3Dataset(data_path, tokenizer, is_train=False)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
n_samples = 0
|
||||
times = []
|
||||
query_embeds = []
|
||||
passage_embeds = []
|
||||
labels = []
|
||||
tracemalloc.start()
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
if n_samples >= max_samples:
|
||||
break
|
||||
start = time.time()
|
||||
out = model(
|
||||
batch['query_input_ids'].to(device),
|
||||
batch['query_attention_mask'].to(device),
|
||||
return_dense=True, return_sparse=False, return_colbert=False
|
||||
)
|
||||
torch.cuda.synchronize() if device.startswith('cuda') else None
|
||||
end = time.time()
|
||||
times.append(end - start)
|
||||
n = batch['query_input_ids'].shape[0]
|
||||
n_samples += n
|
||||
# For metrics: collect embeddings and labels
|
||||
query_embeds.append(out['query_embeds'].cpu().numpy())
|
||||
passage_embeds.append(out['passage_embeds'].cpu().numpy())
|
||||
labels.append(batch['labels'].cpu().numpy())
|
||||
current, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
throughput = n_samples / sum(times)
|
||||
latency = np.mean(times) / batch_size
|
||||
mem_mb = peak / (1024 * 1024)
|
||||
# Prepare for metrics
|
||||
query_embeds = np.concatenate(query_embeds, axis=0)
|
||||
passage_embeds = np.concatenate(passage_embeds, axis=0)
|
||||
labels = np.concatenate(labels, axis=0)
|
||||
return throughput, latency, mem_mb, query_embeds, passage_embeds, labels
|
||||
|
||||
def main():
|
||||
"""Run retriever comparison with robust error handling and config-driven defaults."""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Compare retriever models.")
|
||||
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
|
||||
parser.add_argument('--finetuned_model_path', type=str, required=True, help='Path to fine-tuned retriever model')
|
||||
parser.add_argument('--baseline_model_path', type=str, required=True, help='Path to baseline retriever model')
|
||||
parser.add_argument('--data_path', type=str, required=True, help='Path to evaluation data (JSONL)')
|
||||
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for inference')
|
||||
parser.add_argument('--max_samples', type=int, default=1000, help='Max samples to benchmark (for speed)')
|
||||
parser.add_argument('--output', type=str, default='compare_retriever_results.txt', help='File to save comparison results')
|
||||
parser.add_argument('--k_values', type=int, nargs='+', default=[1, 5, 10, 20, 100], help='K values for metrics@k')
|
||||
args = parser.parse_args()
|
||||
# Load config and set defaults
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
|
||||
try:
|
||||
# Fine-tuned model
|
||||
logger.info("Running fine-tuned retriever...")
|
||||
ft_throughput, ft_latency, ft_mem, ft_qe, ft_pe, ft_labels = run_retriever(
|
||||
args.finetuned_model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
ft_metrics = compute_retrieval_metrics(ft_qe, ft_pe, ft_labels, k_values=args.k_values)
|
||||
# Baseline model
|
||||
logger.info("Running baseline retriever...")
|
||||
bl_throughput, bl_latency, bl_mem, bl_qe, bl_pe, bl_labels = run_retriever(
|
||||
args.baseline_model_path, args.data_path, args.batch_size, args.max_samples, device)
|
||||
bl_metrics = compute_retrieval_metrics(bl_qe, bl_pe, bl_labels, k_values=args.k_values)
|
||||
# Output comparison
|
||||
with open(args.output, 'w') as f:
|
||||
f.write(f"Metric\tBaseline\tFine-tuned\tDelta\n")
|
||||
for k in args.k_values:
|
||||
for metric in [f'recall@{k}', f'precision@{k}', f'map@{k}', f'mrr@{k}', f'ndcg@{k}']:
|
||||
bl = bl_metrics.get(metric, 0)
|
||||
ft = ft_metrics.get(metric, 0)
|
||||
delta = ft - bl
|
||||
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
|
||||
for metric in ['map', 'mrr']:
|
||||
bl = bl_metrics.get(metric, 0)
|
||||
ft = ft_metrics.get(metric, 0)
|
||||
delta = ft - bl
|
||||
f.write(f"{metric}\t{bl:.4f}\t{ft:.4f}\t{delta:+.4f}\n")
|
||||
f.write(f"Throughput (samples/sec)\t{bl_throughput:.2f}\t{ft_throughput:.2f}\t{ft_throughput-bl_throughput:+.2f}\n")
|
||||
f.write(f"Latency (ms/sample)\t{bl_latency*1000:.2f}\t{ft_latency*1000:.2f}\t{(ft_latency-bl_latency)*1000:+.2f}\n")
|
||||
f.write(f"Peak memory (MB)\t{bl_mem:.2f}\t{ft_mem:.2f}\t{ft_mem-bl_mem:+.2f}\n")
|
||||
logger.info(f"Comparison results saved to {args.output}")
|
||||
print(f"\nComparison complete. See {args.output} for details.")
|
||||
except Exception as e:
|
||||
logging.error(f"Retriever comparison failed: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
40
scripts/concat_jsonl.py
Normal file
40
scripts/concat_jsonl.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Concatenate all JSONL files in a directory into a single JSONL file.
|
||||
"""
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
|
||||
def concat_jsonl_files(input_dir: pathlib.Path, output_file: pathlib.Path):
|
||||
"""
|
||||
Read all .jsonl files in the input directory, concatenate their lines,
|
||||
and write them to the output file.
|
||||
|
||||
Args:
|
||||
input_dir: Path to directory containing .jsonl files.
|
||||
output_file: Path to the resulting concatenated .jsonl file.
|
||||
"""
|
||||
jsonl_files = sorted(input_dir.glob('*.jsonl'))
|
||||
if not jsonl_files:
|
||||
print(f"No JSONL files found in {input_dir}")
|
||||
return
|
||||
|
||||
with output_file.open('w', encoding='utf-8') as fout:
|
||||
for jsonl in jsonl_files:
|
||||
with jsonl.open('r', encoding='utf-8') as fin:
|
||||
for line in fin:
|
||||
if line.strip(): # skip empty lines
|
||||
fout.write(line)
|
||||
print(f"Concatenated {len(jsonl_files)} files into {output_file}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Concatenate JSONL files.")
|
||||
parser.add_argument('input_dir', type=pathlib.Path, help='Directory with JSONL files')
|
||||
parser.add_argument('output_file', type=pathlib.Path, help='Output JSONL file')
|
||||
args = parser.parse_args()
|
||||
concat_jsonl_files(args.input_dir, args.output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
145
scripts/convert_reranker_dataset.py
Normal file
145
scripts/convert_reranker_dataset.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
convert_jsonl_with_fallback.py
|
||||
|
||||
Like your original converter, but when inner JSON parsing fails
|
||||
(it often does when there are unescaped quotes inside the passage),
|
||||
we do a simple regex/string‐search to pull out passage, label, score.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(
|
||||
description="Convert JSONL with embedded JSON in 'data' to flat JSONL, with fallback."
|
||||
)
|
||||
p.add_argument(
|
||||
"--input", "-i", required=True,
|
||||
help="Path to input .jsonl file or directory containing .jsonl files"
|
||||
)
|
||||
p.add_argument(
|
||||
"--output-dir", "-o", required=True,
|
||||
help="Directory where converted files will be written"
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
def strip_code_fence(s: str) -> str:
|
||||
"""Remove ```json ... ``` or ``` ... ``` fences, return inner text."""
|
||||
fenced = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)
|
||||
m = fenced.search(s)
|
||||
return m.group(1) if m else s
|
||||
|
||||
def fallback_extract(inner: str):
|
||||
"""
|
||||
Fallback extractor that:
|
||||
- finds the passage between "passage": " … ",
|
||||
- then pulls label and score via regex.
|
||||
Returns dict or None.
|
||||
"""
|
||||
# 1) passage: find between `"passage": "` and the next `",`
|
||||
p_start = inner.find('"passage": "')
|
||||
if p_start < 0:
|
||||
return None
|
||||
p_start += len('"passage": "')
|
||||
# we assume the passage ends at the first occurrence of `",` after p_start
|
||||
p_end = inner.find('",', p_start)
|
||||
if p_end < 0:
|
||||
return None
|
||||
passage = inner[p_start:p_end]
|
||||
|
||||
# 2) label
|
||||
m_lbl = re.search(r'"label"\s*:\s*(\d+)', inner[p_end:])
|
||||
if not m_lbl:
|
||||
return None
|
||||
label = int(m_lbl.group(1))
|
||||
|
||||
# 3) score
|
||||
m_sc = re.search(r'"score"\s*:\s*([0-9]*\.?[0-9]+)', inner[p_end:])
|
||||
if not m_sc:
|
||||
return None
|
||||
score = float(m_sc.group(1))
|
||||
|
||||
return {"passage": passage, "label": label, "score": score}
|
||||
|
||||
def process_file(in_path: Path, out_path: Path):
|
||||
with in_path.open("r", encoding="utf-8") as fin, \
|
||||
out_path.open("w", encoding="utf-8") as fout:
|
||||
|
||||
for lineno, line in enumerate(fin, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# parse the outer JSON
|
||||
try:
|
||||
outer = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
print(f"[WARN] line {lineno} in {in_path.name}: outer JSON invalid, skipping", file=sys.stderr)
|
||||
continue
|
||||
|
||||
query = outer.get("query")
|
||||
data_field = outer.get("data")
|
||||
if not query or not data_field:
|
||||
# missing query or data → skip
|
||||
continue
|
||||
|
||||
inner_text = strip_code_fence(data_field)
|
||||
|
||||
# attempt normal JSON parse
|
||||
record = None
|
||||
try:
|
||||
inner = json.loads(inner_text)
|
||||
# must have all three keys
|
||||
if all(k in inner for k in ("passage", "label", "score")):
|
||||
record = {
|
||||
"query": query,
|
||||
"passage": inner["passage"],
|
||||
"label": inner["label"],
|
||||
"score": inner["score"]
|
||||
}
|
||||
# else record stays None → fallback below
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# if JSON parse gave nothing, try fallback
|
||||
if record is None:
|
||||
fb = fallback_extract(inner_text)
|
||||
if fb:
|
||||
record = {
|
||||
"query": query,
|
||||
**fb
|
||||
}
|
||||
print(f"[INFO] line {lineno} in {in_path.name}: used fallback extraction", file=sys.stderr)
|
||||
else:
|
||||
print(f"[WARN] line {lineno} in {in_path.name}: unable to extract fields, skipping", file=sys.stderr)
|
||||
continue
|
||||
|
||||
fout.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
inp = Path(args.input)
|
||||
outdir = Path(args.output_dir)
|
||||
outdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if inp.is_dir():
|
||||
files = list(inp.glob("*.jsonl"))
|
||||
else:
|
||||
files = [inp]
|
||||
|
||||
if not files:
|
||||
print(f"No .jsonl files found in {inp}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
for f in files:
|
||||
dest = outdir / f.name
|
||||
print(f"Converting {f} → {dest}", file=sys.stderr)
|
||||
process_file(f, dest)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,584 +0,0 @@
|
||||
"""
|
||||
Evaluation script for fine-tuned BGE models
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
import numpy as np
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
|
||||
from data.preprocessing import DataPreprocessor
|
||||
from utils.logging import setup_logging, get_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments for evaluation."""
|
||||
parser = argparse.ArgumentParser(description="Evaluate fine-tuned BGE models")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--retriever_model", type=str, required=True,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
|
||||
help="Base retriever model for comparison")
|
||||
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Base reranker model for comparison")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--eval_data", type=str, required=True,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--corpus_data", type=str, default=None,
|
||||
help="Path to corpus data for retrieval evaluation")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
|
||||
help="Format of input data")
|
||||
|
||||
# Evaluation arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
|
||||
help="Directory to save evaluation results")
|
||||
parser.add_argument("--batch_size", type=int, default=32,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_length", type=int, default=512,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
|
||||
help="K values for evaluation metrics")
|
||||
|
||||
# Evaluation modes
|
||||
parser.add_argument("--eval_retriever", action="store_true",
|
||||
help="Evaluate retriever model")
|
||||
parser.add_argument("--eval_reranker", action="store_true",
|
||||
help="Evaluate reranker model")
|
||||
parser.add_argument("--eval_pipeline", action="store_true",
|
||||
help="Evaluate full retrieval pipeline")
|
||||
parser.add_argument("--compare_baseline", action="store_true",
|
||||
help="Compare with baseline models")
|
||||
parser.add_argument("--interactive", action="store_true",
|
||||
help="Run interactive evaluation")
|
||||
|
||||
# Pipeline settings
|
||||
parser.add_argument("--retrieval_top_k", type=int, default=100,
|
||||
help="Top-k for initial retrieval")
|
||||
parser.add_argument("--rerank_top_k", type=int, default=10,
|
||||
help="Top-k for reranking")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--device", type=str, default="auto",
|
||||
help="Device to use (auto, cpu, cuda)")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validation
|
||||
if not any([args.eval_retriever, args.eval_reranker, args.eval_pipeline, args.interactive]):
|
||||
args.eval_pipeline = True # Default to pipeline evaluation
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def load_evaluation_data(data_path: str, data_format: str):
|
||||
"""Load evaluation data from a file."""
|
||||
preprocessor = DataPreprocessor()
|
||||
|
||||
if data_format == "auto":
|
||||
data_format = preprocessor._detect_format(data_path)
|
||||
|
||||
if data_format == "jsonl":
|
||||
data = preprocessor._load_jsonl(data_path)
|
||||
elif data_format == "json":
|
||||
data = preprocessor._load_json(data_path)
|
||||
elif data_format in ["csv", "tsv"]:
|
||||
data = preprocessor._load_csv(data_path, delimiter="," if data_format == "csv" else "\t")
|
||||
elif data_format == "dureader":
|
||||
data = preprocessor._load_dureader(data_path)
|
||||
elif data_format == "msmarco":
|
||||
data = preprocessor._load_msmarco(data_path)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {data_format}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def evaluate_retriever(args, retriever_model, tokenizer, eval_data):
|
||||
"""Evaluate the retriever model."""
|
||||
logger.info("Evaluating retriever model...")
|
||||
|
||||
evaluator = RetrievalEvaluator(
|
||||
model=retriever_model,
|
||||
tokenizer=tokenizer,
|
||||
device=args.device,
|
||||
k_values=args.k_values
|
||||
)
|
||||
|
||||
# Prepare data for evaluation
|
||||
queries = [item['query'] for item in eval_data]
|
||||
|
||||
# If we have passage data, create passage corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
# Create relevance mapping
|
||||
relevant_docs = {}
|
||||
for i, item in enumerate(eval_data):
|
||||
query = item['query']
|
||||
if 'pos' in item:
|
||||
relevant_indices = []
|
||||
for pos in item['pos']:
|
||||
if pos in passages:
|
||||
relevant_indices.append(passages.index(pos))
|
||||
if relevant_indices:
|
||||
relevant_docs[query] = relevant_indices
|
||||
|
||||
# Evaluate
|
||||
if passages and relevant_docs:
|
||||
metrics = evaluator.evaluate_retrieval_pipeline(
|
||||
queries=queries,
|
||||
corpus=passages,
|
||||
relevant_docs=relevant_docs,
|
||||
batch_size=args.batch_size
|
||||
)
|
||||
else:
|
||||
logger.warning("No valid corpus or relevance data found. Running basic evaluation.")
|
||||
# Create dummy evaluation
|
||||
metrics = {"note": "Limited evaluation due to data format"}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def evaluate_reranker(args, reranker_model, tokenizer, eval_data):
|
||||
"""Evaluate the reranker model."""
|
||||
logger.info("Evaluating reranker model...")
|
||||
|
||||
evaluator = RetrievalEvaluator(
|
||||
model=reranker_model,
|
||||
tokenizer=tokenizer,
|
||||
device=args.device,
|
||||
k_values=args.k_values
|
||||
)
|
||||
|
||||
# Prepare data for reranker evaluation
|
||||
total_queries = 0
|
||||
total_correct = 0
|
||||
total_mrr = 0
|
||||
|
||||
for item in eval_data:
|
||||
if 'pos' not in item or 'neg' not in item:
|
||||
continue
|
||||
|
||||
query = item['query']
|
||||
positives = item['pos']
|
||||
negatives = item['neg']
|
||||
|
||||
if not positives or not negatives:
|
||||
continue
|
||||
|
||||
# Create candidates (1 positive + negatives)
|
||||
candidates = positives[:1] + negatives
|
||||
|
||||
# Score all candidates
|
||||
pairs = [(query, candidate) for candidate in candidates]
|
||||
scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
|
||||
|
||||
# Check if positive is ranked first
|
||||
best_idx = np.argmax(scores)
|
||||
if best_idx == 0: # First is positive
|
||||
total_correct += 1
|
||||
total_mrr += 1.0
|
||||
else:
|
||||
# Find rank of positive (it's always at index 0)
|
||||
sorted_indices = np.argsort(scores)[::-1]
|
||||
positive_rank = np.where(sorted_indices == 0)[0][0] + 1
|
||||
total_mrr += 1.0 / positive_rank
|
||||
|
||||
total_queries += 1
|
||||
|
||||
if total_queries > 0:
|
||||
metrics = {
|
||||
'accuracy@1': total_correct / total_queries,
|
||||
'mrr': total_mrr / total_queries,
|
||||
'total_queries': total_queries
|
||||
}
|
||||
else:
|
||||
metrics = {"note": "No valid query-passage pairs found for reranker evaluation"}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def evaluate_pipeline(args, retriever_model, reranker_model, tokenizer, eval_data):
|
||||
"""Evaluate the full retrieval + reranking pipeline."""
|
||||
logger.info("Evaluating full pipeline...")
|
||||
|
||||
# Prepare corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
if not passages:
|
||||
return {"note": "No passages found for pipeline evaluation"}
|
||||
|
||||
# Pre-encode corpus with retriever
|
||||
logger.info("Encoding corpus...")
|
||||
corpus_embeddings = retriever_model.encode(
|
||||
passages,
|
||||
batch_size=args.batch_size,
|
||||
return_numpy=True,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
total_queries = 0
|
||||
pipeline_metrics = {
|
||||
'retrieval_recall@10': [],
|
||||
'rerank_accuracy@1': [],
|
||||
'rerank_mrr': []
|
||||
}
|
||||
|
||||
for item in eval_data:
|
||||
if 'pos' not in item or not item['pos']:
|
||||
continue
|
||||
|
||||
query = item['query']
|
||||
relevant_passages = item['pos']
|
||||
|
||||
# Step 1: Retrieval
|
||||
query_embedding = retriever_model.encode(
|
||||
[query],
|
||||
return_numpy=True,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
# Compute similarities
|
||||
similarities = np.dot(query_embedding, corpus_embeddings.T).squeeze()
|
||||
|
||||
# Get top-k for retrieval
|
||||
top_k_indices = np.argsort(similarities)[::-1][:args.retrieval_top_k]
|
||||
retrieved_passages = [passages[i] for i in top_k_indices]
|
||||
|
||||
# Check retrieval recall
|
||||
retrieved_relevant = sum(1 for p in retrieved_passages[:10] if p in relevant_passages)
|
||||
retrieval_recall = retrieved_relevant / min(len(relevant_passages), 10)
|
||||
pipeline_metrics['retrieval_recall@10'].append(retrieval_recall)
|
||||
|
||||
# Step 2: Reranking (if reranker available)
|
||||
if reranker_model is not None:
|
||||
# Rerank top candidates
|
||||
rerank_candidates = retrieved_passages[:args.rerank_top_k]
|
||||
pairs = [(query, candidate) for candidate in rerank_candidates]
|
||||
|
||||
if pairs:
|
||||
rerank_scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
|
||||
reranked_indices = np.argsort(rerank_scores)[::-1]
|
||||
reranked_passages = [rerank_candidates[i] for i in reranked_indices]
|
||||
|
||||
# Check reranking performance
|
||||
if reranked_passages[0] in relevant_passages:
|
||||
pipeline_metrics['rerank_accuracy@1'].append(1.0)
|
||||
pipeline_metrics['rerank_mrr'].append(1.0)
|
||||
else:
|
||||
pipeline_metrics['rerank_accuracy@1'].append(0.0)
|
||||
# Find MRR
|
||||
for rank, passage in enumerate(reranked_passages, 1):
|
||||
if passage in relevant_passages:
|
||||
pipeline_metrics['rerank_mrr'].append(1.0 / rank)
|
||||
break
|
||||
else:
|
||||
pipeline_metrics['rerank_mrr'].append(0.0)
|
||||
|
||||
total_queries += 1
|
||||
|
||||
# Average metrics
|
||||
final_metrics = {}
|
||||
for metric, values in pipeline_metrics.items():
|
||||
if values:
|
||||
final_metrics[metric] = np.mean(values)
|
||||
|
||||
final_metrics['total_queries'] = total_queries
|
||||
|
||||
return final_metrics
|
||||
|
||||
|
||||
def run_interactive_evaluation(args, retriever_model, reranker_model, tokenizer, eval_data):
|
||||
"""Run interactive evaluation."""
|
||||
logger.info("Starting interactive evaluation...")
|
||||
|
||||
# Prepare corpus
|
||||
all_passages = set()
|
||||
for item in eval_data:
|
||||
if 'pos' in item:
|
||||
all_passages.update(item['pos'])
|
||||
if 'neg' in item:
|
||||
all_passages.update(item['neg'])
|
||||
|
||||
passages = list(all_passages)
|
||||
|
||||
if not passages:
|
||||
print("No passages found for interactive evaluation")
|
||||
return
|
||||
|
||||
# Create interactive evaluator
|
||||
evaluator = InteractiveEvaluator(
|
||||
retriever=retriever_model,
|
||||
reranker=reranker_model,
|
||||
tokenizer=tokenizer,
|
||||
corpus=passages,
|
||||
device=args.device
|
||||
)
|
||||
|
||||
print("\n=== Interactive Evaluation ===")
|
||||
print("Enter queries to test the retrieval system.")
|
||||
print("Type 'quit' to exit, 'sample' to test with sample queries.")
|
||||
|
||||
# Load some sample queries
|
||||
sample_queries = [item['query'] for item in eval_data[:5]]
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\nEnter query: ").strip()
|
||||
|
||||
if user_input.lower() in ['quit', 'exit', 'q']:
|
||||
break
|
||||
elif user_input.lower() == 'sample':
|
||||
print("\nSample queries:")
|
||||
for i, query in enumerate(sample_queries):
|
||||
print(f"{i + 1}. {query}")
|
||||
continue
|
||||
elif user_input.isdigit() and 1 <= int(user_input) <= len(sample_queries):
|
||||
query = sample_queries[int(user_input) - 1]
|
||||
else:
|
||||
query = user_input
|
||||
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# Search
|
||||
results = evaluator.search(query, top_k=10, rerank=reranker_model is not None)
|
||||
|
||||
print(f"\nResults for: {query}")
|
||||
print("-" * 50)
|
||||
for i, (idx, score, text) in enumerate(results):
|
||||
print(f"{i + 1}. (Score: {score:.3f}) {text[:100]}...")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting interactive evaluation...")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run evaluation with robust error handling and config-driven defaults."""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Evaluate BGE models.")
|
||||
parser.add_argument('--device', type=str, default=None, help='Device to use (cuda, npu, cpu)')
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--retriever_model", type=str, required=True,
|
||||
help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, default=None,
|
||||
help="Path to fine-tuned reranker model")
|
||||
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
|
||||
help="Base retriever model for comparison")
|
||||
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Base reranker model for comparison")
|
||||
|
||||
# Data arguments
|
||||
parser.add_argument("--eval_data", type=str, required=True,
|
||||
help="Path to evaluation data file")
|
||||
parser.add_argument("--corpus_data", type=str, default=None,
|
||||
help="Path to corpus data for retrieval evaluation")
|
||||
parser.add_argument("--data_format", type=str, default="auto",
|
||||
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
|
||||
help="Format of input data")
|
||||
|
||||
# Evaluation arguments
|
||||
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
|
||||
help="Directory to save evaluation results")
|
||||
parser.add_argument("--batch_size", type=int, default=32,
|
||||
help="Batch size for evaluation")
|
||||
parser.add_argument("--max_length", type=int, default=512,
|
||||
help="Maximum sequence length")
|
||||
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
|
||||
help="K values for evaluation metrics")
|
||||
|
||||
# Evaluation modes
|
||||
parser.add_argument("--eval_retriever", action="store_true",
|
||||
help="Evaluate retriever model")
|
||||
parser.add_argument("--eval_reranker", action="store_true",
|
||||
help="Evaluate reranker model")
|
||||
parser.add_argument("--eval_pipeline", action="store_true",
|
||||
help="Evaluate full retrieval pipeline")
|
||||
parser.add_argument("--compare_baseline", action="store_true",
|
||||
help="Compare with baseline models")
|
||||
parser.add_argument("--interactive", action="store_true",
|
||||
help="Run interactive evaluation")
|
||||
|
||||
# Pipeline settings
|
||||
parser.add_argument("--retrieval_top_k", type=int, default=100,
|
||||
help="Top-k for initial retrieval")
|
||||
parser.add_argument("--rerank_top_k", type=int, default=10,
|
||||
help="Top-k for reranking")
|
||||
|
||||
# Other
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load config and set defaults
|
||||
from utils.config_loader import get_config
|
||||
config = get_config()
|
||||
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
|
||||
|
||||
# Setup logging
|
||||
setup_logging(
|
||||
output_dir=args.output_dir,
|
||||
log_level=logging.INFO,
|
||||
log_to_console=True,
|
||||
log_to_file=True
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting model evaluation")
|
||||
logger.info(f"Arguments: {args}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Set device
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
elif device == "npu":
|
||||
npu = getattr(torch, "npu", None)
|
||||
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
|
||||
device = "npu"
|
||||
else:
|
||||
logger.warning("NPU (Ascend) is not available. Falling back to CPU.")
|
||||
device = "cpu"
|
||||
# Any other explicit value (e.g. "cuda", "cuda:0", "cpu") is used as-is.
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Load models
|
||||
logger.info("Loading models...")
|
||||
model_source = config.get('model_paths.source', 'huggingface')  # reuse the config loaded above
|
||||
# Load retriever
|
||||
try:
|
||||
retriever_model = BGEM3Model.from_pretrained(args.retriever_model)
|
||||
retriever_tokenizer = None # Already loaded inside BGEM3Model if needed
|
||||
retriever_model.to(device)
|
||||
retriever_model.eval()
|
||||
logger.info(f"Loaded retriever from {args.retriever_model} (source: {model_source})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load retriever: {e}")
|
||||
return
|
||||
|
||||
# Load reranker (optional)
|
||||
reranker_model = None
|
||||
reranker_tokenizer = None
|
||||
if args.reranker_model:
|
||||
try:
|
||||
reranker_model = BGERerankerModel(args.reranker_model)
|
||||
reranker_tokenizer = None # Already loaded inside BGERerankerModel if needed
|
||||
reranker_model.to(device)
|
||||
reranker_model.eval()
|
||||
logger.info(f"Loaded reranker from {args.reranker_model} (source: {model_source})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load reranker: {e}")
|
||||
|
||||
# Load evaluation data
|
||||
logger.info("Loading evaluation data...")
|
||||
eval_data = load_evaluation_data(args.eval_data, args.data_format)
|
||||
logger.info(f"Loaded {len(eval_data)} evaluation examples")
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Run evaluations
|
||||
all_results = {}
|
||||
|
||||
if args.eval_retriever:
|
||||
retriever_metrics = evaluate_retriever(args, retriever_model, retriever_tokenizer, eval_data)
|
||||
all_results['retriever'] = retriever_metrics
|
||||
logger.info(f"Retriever metrics: {retriever_metrics}")
|
||||
|
||||
if args.eval_reranker and reranker_model:
|
||||
reranker_metrics = evaluate_reranker(args, reranker_model, reranker_tokenizer, eval_data)
|
||||
all_results['reranker'] = reranker_metrics
|
||||
logger.info(f"Reranker metrics: {reranker_metrics}")
|
||||
|
||||
if args.eval_pipeline:
|
||||
pipeline_metrics = evaluate_pipeline(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
|
||||
all_results['pipeline'] = pipeline_metrics
|
||||
logger.info(f"Pipeline metrics: {pipeline_metrics}")
|
||||
|
||||
# Compare with baseline if requested
|
||||
if args.compare_baseline:
|
||||
logger.info("Loading baseline models for comparison...")
|
||||
try:
|
||||
base_retriever = BGEM3Model(args.base_retriever)
|
||||
base_retriever.to(device)
|
||||
base_retriever.eval()
|
||||
|
||||
base_retriever_metrics = evaluate_retriever(args, base_retriever, retriever_tokenizer, eval_data)
|
||||
all_results['baseline_retriever'] = base_retriever_metrics
|
||||
logger.info(f"Baseline retriever metrics: {base_retriever_metrics}")
|
||||
|
||||
if reranker_model:
|
||||
base_reranker = BGERerankerModel(args.base_reranker)
|
||||
base_reranker.to(device)
|
||||
base_reranker.eval()
|
||||
|
||||
base_reranker_metrics = evaluate_reranker(args, base_reranker, reranker_tokenizer, eval_data)
|
||||
all_results['baseline_reranker'] = base_reranker_metrics
|
||||
logger.info(f"Baseline reranker metrics: {base_reranker_metrics}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load baseline models: {e}")
|
||||
|
||||
# Save results
|
||||
results_file = os.path.join(args.output_dir, "evaluation_results.json")
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
|
||||
logger.info(f"Evaluation results saved to {results_file}")
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("EVALUATION SUMMARY")
|
||||
print("=" * 50)
|
||||
for model_type, metrics in all_results.items():
|
||||
print(f"\n{model_type.upper()}:")
|
||||
for metric, value in metrics.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.4f}")
|
||||
else:
|
||||
print(f" {metric}: {value}")
|
||||
|
||||
# Interactive evaluation
|
||||
if args.interactive:
|
||||
run_interactive_evaluation(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
207
scripts/quick_train_test.py
Normal file
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick Training Test Script for BGE Models
|
||||
|
||||
This script performs a quick training test on small data subsets to verify
|
||||
that fine-tuning is working before committing to full training.
|
||||
|
||||
Usage:
|
||||
python scripts/quick_train_test.py \
|
||||
--data_dir "data/datasets/三国演义/splits" \
|
||||
--output_dir "./output/quick_test" \
|
||||
--samples_per_model 1000
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_subset(input_file: str, output_file: str, num_samples: int):
|
||||
"""Create a subset of the dataset for quick testing"""
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
|
||||
with open(input_file, 'r', encoding='utf-8') as f_in:
|
||||
with open(output_file, 'w', encoding='utf-8') as f_out:
|
||||
for i, line in enumerate(f_in):
|
||||
if i >= num_samples:
|
||||
break
|
||||
f_out.write(line)
|
||||
|
||||
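# Note: the subset is simply the first num_samples lines; no extra shuffling is
# done here, since scripts/split_datasets.py shuffles before writing the splits.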
logger.info(f"Created subset: {output_file} with {num_samples} samples")
|
||||
|
||||
|
||||
def run_command(cmd: list, description: str):
|
||||
"""Run a command and handle errors"""
|
||||
logger.info(f"🚀 {description}")
|
||||
logger.info(f"Command: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info(f"✅ {description} completed successfully")
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"❌ {description} failed:")
|
||||
logger.error(f"Error: {e.stderr}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Quick training test for BGE models")
|
||||
parser.add_argument("--data_dir", type=str, required=True, help="Directory with split datasets")
|
||||
parser.add_argument("--output_dir", type=str, default="./output/quick_test", help="Output directory")
|
||||
parser.add_argument("--samples_per_model", type=int, default=1000, help="Number of samples for quick test")
|
||||
parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
data_dir = Path(args.data_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info("🏃♂️ Starting Quick Training Test")
|
||||
logger.info(f"📁 Data directory: {data_dir}")
|
||||
logger.info(f"📁 Output directory: {output_dir}")
|
||||
logger.info(f"🔢 Samples per model: {args.samples_per_model}")
|
||||
|
||||
# Create subset directories
|
||||
subset_dir = output_dir / "subsets"
|
||||
subset_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create BGE-M3 subsets
|
||||
m3_train_subset = subset_dir / "m3_train_subset.jsonl"
|
||||
m3_val_subset = subset_dir / "m3_val_subset.jsonl"
|
||||
|
||||
if (data_dir / "m3_train.jsonl").exists():
|
||||
create_subset(str(data_dir / "m3_train.jsonl"), str(m3_train_subset), args.samples_per_model)
|
||||
create_subset(str(data_dir / "m3_val.jsonl"), str(m3_val_subset), args.samples_per_model // 10)
|
||||
|
||||
# Create Reranker subsets
|
||||
reranker_train_subset = subset_dir / "reranker_train_subset.jsonl"
|
||||
reranker_val_subset = subset_dir / "reranker_val_subset.jsonl"
|
||||
|
||||
if (data_dir / "reranker_train.jsonl").exists():
|
||||
create_subset(str(data_dir / "reranker_train.jsonl"), str(reranker_train_subset), args.samples_per_model)
|
||||
create_subset(str(data_dir / "reranker_val.jsonl"), str(reranker_val_subset), args.samples_per_model // 10)
|
||||
|
||||
# Test BGE-M3 training
|
||||
if m3_train_subset.exists():
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("🔍 Testing BGE-M3 Training")
|
||||
logger.info("="*60)
|
||||
|
||||
m3_output = output_dir / "bge_m3_test"
|
||||
m3_cmd = [
|
||||
"python", "scripts/train_m3.py",
|
||||
"--train_data", str(m3_train_subset),
|
||||
"--eval_data", str(m3_val_subset),
|
||||
"--output_dir", str(m3_output),
|
||||
"--num_train_epochs", str(args.epochs),
|
||||
"--per_device_train_batch_size", str(args.batch_size),
|
||||
"--per_device_eval_batch_size", str(args.batch_size * 2),
|
||||
"--learning_rate", "1e-5",
|
||||
"--save_steps", "100",
|
||||
"--eval_steps", "100",
|
||||
"--logging_steps", "50",
|
||||
"--warmup_ratio", "0.1",
|
||||
"--gradient_accumulation_steps", "2"
|
||||
]
|
||||
|
||||
m3_success = run_command(m3_cmd, "BGE-M3 Quick Training")
|
||||
else:
|
||||
m3_success = False
|
||||
logger.warning("BGE-M3 training data not found")
|
||||
|
||||
# Test Reranker training
|
||||
if reranker_train_subset.exists():
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("🏆 Testing Reranker Training")
|
||||
logger.info("="*60)
|
||||
|
||||
reranker_output = output_dir / "reranker_test"
|
||||
reranker_cmd = [
|
||||
"python", "scripts/train_reranker.py",
|
||||
"--reranker_data", str(reranker_train_subset),
|
||||
"--eval_data", str(reranker_val_subset),
|
||||
"--output_dir", str(reranker_output),
|
||||
"--num_train_epochs", str(args.epochs),
|
||||
"--per_device_train_batch_size", str(args.batch_size),
|
||||
"--per_device_eval_batch_size", str(args.batch_size * 2),
|
||||
"--learning_rate", "6e-5",
|
||||
"--save_steps", "100",
|
||||
"--eval_steps", "100",
|
||||
"--logging_steps", "50",
|
||||
"--warmup_ratio", "0.1"
|
||||
]
|
||||
|
||||
reranker_success = run_command(reranker_cmd, "Reranker Quick Training")
|
||||
else:
|
||||
reranker_success = False
|
||||
logger.warning("Reranker training data not found")
|
||||
|
||||
# Quick validation
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("🧪 Running Quick Validation")
|
||||
logger.info("="*60)
|
||||
|
||||
validation_results = {}
|
||||
|
||||
# Validate BGE-M3 if trained
|
||||
if m3_success and (m3_output / "final_model").exists():
|
||||
m3_val_cmd = [
|
||||
"python", "scripts/quick_validation.py",
|
||||
"--retriever_model", str(m3_output / "final_model")
|
||||
]
|
||||
validation_results['m3_validation'] = run_command(m3_val_cmd, "BGE-M3 Validation")
|
||||
|
||||
# Validate Reranker if trained
|
||||
if reranker_success and (reranker_output / "final_model").exists():
|
||||
reranker_val_cmd = [
|
||||
"python", "scripts/quick_validation.py",
|
||||
"--reranker_model", str(reranker_output / "final_model")
|
||||
]
|
||||
validation_results['reranker_validation'] = run_command(reranker_val_cmd, "Reranker Validation")
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("📋 Quick Training Test Summary")
|
||||
logger.info("="*60)
|
||||
|
||||
summary = {
|
||||
"m3_training": m3_success,
|
||||
"reranker_training": reranker_success,
|
||||
"validation_results": validation_results,
|
||||
"recommendation": ""
|
||||
}
|
||||
|
||||
if m3_success and reranker_success:
|
||||
summary["recommendation"] = "✅ Both models trained successfully! Ready for full training."
|
||||
logger.info("✅ Both models trained successfully!")
|
||||
logger.info("🎯 You can now proceed with full training using the complete datasets.")
|
||||
elif m3_success or reranker_success:
|
||||
summary["recommendation"] = "⚠️ Partial success. Check failed model configuration."
|
||||
logger.info("⚠️ Partial success. One model failed - check logs above.")
|
||||
else:
|
||||
summary["recommendation"] = "❌ Training failed. Check configuration and data."
|
||||
logger.info("❌ Training tests failed. Check your configuration and data format.")
|
||||
|
||||
# Save summary
|
||||
with open(output_dir / "quick_test_summary.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(summary, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"📋 Summary saved to {output_dir / 'quick_test_summary.json'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -265,7 +265,8 @@ class ValidationSuiteRunner:
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
'python', 'scripts/compare_retriever.py',
|
||||
'python', 'scripts/compare_models.py',
|
||||
'--model_type', 'retriever',
|
||||
'--finetuned_model_path', self.config['retriever_model'],
|
||||
'--baseline_model_path', self.config.get('retriever_baseline', 'BAAI/bge-m3'),
|
||||
'--data_path', dataset_path,
|
||||
@@ -301,7 +302,8 @@ class ValidationSuiteRunner:
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
'python', 'scripts/compare_reranker.py',
|
||||
'python', 'scripts/compare_models.py',
|
||||
'--model_type', 'reranker',
|
||||
'--finetuned_model_path', self.config['reranker_model'],
|
||||
'--baseline_model_path', self.config.get('reranker_baseline', 'BAAI/bge-reranker-base'),
|
||||
'--data_path', dataset_path,
|
||||
|
||||
@@ -341,7 +341,7 @@ def print_usage_info():
|
||||
print(" python scripts/train_m3.py --train_data data/raw/your_data.jsonl")
|
||||
print(" python scripts/train_reranker.py --train_data data/raw/your_data.jsonl")
|
||||
print("4. Evaluate models:")
|
||||
print(" python scripts/evaluate.py --retriever_model output/model --eval_data data/raw/test.jsonl")
|
||||
print(" python scripts/validate.py quick --retriever_model output/model")
|
||||
print("\nSample data created in: data/raw/sample_data.jsonl")
|
||||
print("Configuration file: config.toml")
|
||||
print("\nFor more information, see the README.md file.")
|
||||
|
||||
177
scripts/split_datasets.py
Normal file
@@ -0,0 +1,177 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Dataset Splitting Script for BGE Fine-tuning
|
||||
|
||||
This script randomly splits large datasets into train/validation/test sets. A fixed
seed keeps the splits reproducible, and shuffling before splitting keeps each split's
distribution close to that of the full dataset, enabling proper held-out evaluation.
|
||||
|
||||
Usage:
|
||||
python scripts/split_datasets.py \
|
||||
--input_dir "data/datasets/三国演义" \
|
||||
--output_dir "data/datasets/三国演义/splits" \
|
||||
--train_ratio 0.8 \
|
||||
--val_ratio 0.1 \
|
||||
--test_ratio 0.1
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_jsonl(file_path: str) -> List[Dict]:
|
||||
"""Load JSONL file"""
|
||||
data = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
try:
|
||||
item = json.loads(line.strip())
|
||||
data.append(item)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Skipping invalid JSON at line {line_num}: {e}")
|
||||
continue
|
||||
return data
|
||||
|
||||
|
||||
def save_jsonl(data: List[Dict], file_path: str):
|
||||
"""Save data to JSONL file"""
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
for item in data:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
|
||||
def split_dataset(data: List[Dict], train_ratio: float, val_ratio: float, test_ratio: float, seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
||||
"""Split dataset into train/validation/test sets"""
|
||||
# Validate ratios
|
||||
total_ratio = train_ratio + val_ratio + test_ratio
|
||||
if abs(total_ratio - 1.0) > 1e-6:
|
||||
raise ValueError(f"Ratios must sum to 1.0, got {total_ratio}")
|
||||
|
||||
# Shuffle data
|
||||
random.seed(seed)
|
||||
data_shuffled = data.copy()
|
||||
random.shuffle(data_shuffled)
|
||||
|
||||
# Calculate split indices
|
||||
total_size = len(data_shuffled)
|
||||
train_size = int(total_size * train_ratio)
|
||||
val_size = int(total_size * val_ratio)
|
||||
|
||||
# Split data
|
||||
train_data = data_shuffled[:train_size]
|
||||
val_data = data_shuffled[train_size:train_size + val_size]
|
||||
test_data = data_shuffled[train_size + val_size:]
|
||||
|
||||
return train_data, val_data, test_data
|
||||
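# Example of the arithmetic (hypothetical sizes): with 1,000 samples and ratios
# 0.8/0.1/0.1, train_size = 800 and val_size = 100; the test split takes the
# remaining 100 samples and absorbs any rounding remainder.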
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Split BGE datasets for training")
|
||||
parser.add_argument("--input_dir", type=str, required=True, help="Directory containing m3.jsonl and reranker.jsonl")
|
||||
parser.add_argument("--output_dir", type=str, required=True, help="Output directory for split datasets")
|
||||
parser.add_argument("--train_ratio", type=float, default=0.8, help="Training set ratio")
|
||||
parser.add_argument("--val_ratio", type=float, default=0.1, help="Validation set ratio")
|
||||
parser.add_argument("--test_ratio", type=float, default=0.1, help="Test set ratio")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
input_dir = Path(args.input_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"🔄 Splitting datasets from {input_dir}")
|
||||
logger.info(f"📁 Output directory: {output_dir}")
|
||||
logger.info(f"📊 Split ratios - Train: {args.train_ratio}, Val: {args.val_ratio}, Test: {args.test_ratio}")
|
||||
|
||||
# Process BGE-M3 dataset
|
||||
m3_file = input_dir / "m3.jsonl"
|
||||
if m3_file.exists():
|
||||
logger.info("Processing BGE-M3 dataset...")
|
||||
m3_data = load_jsonl(str(m3_file))
|
||||
logger.info(f"Loaded {len(m3_data)} BGE-M3 samples")
|
||||
|
||||
train_m3, val_m3, test_m3 = split_dataset(
|
||||
m3_data, args.train_ratio, args.val_ratio, args.test_ratio, args.seed
|
||||
)
|
||||
|
||||
# Save splits
|
||||
save_jsonl(train_m3, str(output_dir / "m3_train.jsonl"))
|
||||
save_jsonl(val_m3, str(output_dir / "m3_val.jsonl"))
|
||||
save_jsonl(test_m3, str(output_dir / "m3_test.jsonl"))
|
||||
|
||||
logger.info(f"✅ BGE-M3 splits saved: Train={len(train_m3)}, Val={len(val_m3)}, Test={len(test_m3)}")
|
||||
else:
|
||||
logger.warning(f"BGE-M3 file not found: {m3_file}")
|
||||
|
||||
# Process Reranker dataset
|
||||
reranker_file = input_dir / "reranker.jsonl"
|
||||
if reranker_file.exists():
|
||||
logger.info("Processing Reranker dataset...")
|
||||
reranker_data = load_jsonl(str(reranker_file))
|
||||
logger.info(f"Loaded {len(reranker_data)} reranker samples")
|
||||
|
||||
train_reranker, val_reranker, test_reranker = split_dataset(
|
||||
reranker_data, args.train_ratio, args.val_ratio, args.test_ratio, args.seed
|
||||
)
|
||||
|
||||
# Save splits
|
||||
save_jsonl(train_reranker, str(output_dir / "reranker_train.jsonl"))
|
||||
save_jsonl(val_reranker, str(output_dir / "reranker_val.jsonl"))
|
||||
save_jsonl(test_reranker, str(output_dir / "reranker_test.jsonl"))
|
||||
|
||||
logger.info(f"✅ Reranker splits saved: Train={len(train_reranker)}, Val={len(val_reranker)}, Test={len(test_reranker)}")
|
||||
else:
|
||||
logger.warning(f"Reranker file not found: {reranker_file}")
|
||||
|
||||
# Create summary
|
||||
summary = {
|
||||
"input_dir": str(input_dir),
|
||||
"output_dir": str(output_dir),
|
||||
"split_ratios": {
|
||||
"train": args.train_ratio,
|
||||
"validation": args.val_ratio,
|
||||
"test": args.test_ratio
|
||||
},
|
||||
"seed": args.seed,
|
||||
"datasets": {}
|
||||
}
|
||||
|
||||
if m3_file.exists():
|
||||
summary["datasets"]["bge_m3"] = {
|
||||
"total_samples": len(m3_data),
|
||||
"train_samples": len(train_m3),
|
||||
"val_samples": len(val_m3),
|
||||
"test_samples": len(test_m3)
|
||||
}
|
||||
|
||||
if reranker_file.exists():
|
||||
summary["datasets"]["reranker"] = {
|
||||
"total_samples": len(reranker_data),
|
||||
"train_samples": len(train_reranker),
|
||||
"val_samples": len(val_reranker),
|
||||
"test_samples": len(test_reranker)
|
||||
}
|
||||
|
||||
# Save summary
|
||||
with open(output_dir / "split_summary.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(summary, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"📋 Split summary saved to {output_dir / 'split_summary.json'}")
|
||||
logger.info("🎉 Dataset splitting completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
483
scripts/validate.py
Normal file
@@ -0,0 +1,483 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BGE Model Validation - Single Entry Point
|
||||
|
||||
This script provides a unified, streamlined interface for all BGE model validation needs.
|
||||
It replaces the multiple confusing validation scripts with one clear, easy-to-use tool.
|
||||
|
||||
Usage Examples:
|
||||
|
||||
# Quick validation (5 minutes)
|
||||
python scripts/validate.py quick \
|
||||
--retriever_model ./output/bge_m3_三国演义/final_model \
|
||||
--reranker_model ./output/bge_reranker_三国演义/final_model
|
||||
|
||||
# Comprehensive validation (30 minutes)
|
||||
python scripts/validate.py comprehensive \
|
||||
--retriever_model ./output/bge_m3_三国演义/final_model \
|
||||
--reranker_model ./output/bge_reranker_三国演义/final_model \
|
||||
--test_data_dir "data/datasets/三国演义/splits"
|
||||
|
||||
# Compare against baselines
|
||||
python scripts/validate.py compare \
|
||||
--retriever_model ./output/bge_m3_三国演义/final_model \
|
||||
--reranker_model ./output/bge_reranker_三国演义/final_model \
|
||||
--retriever_data "data/datasets/三国演义/splits/m3_test.jsonl" \
|
||||
--reranker_data "data/datasets/三国演义/splits/reranker_test.jsonl"
|
||||
|
||||
# Performance benchmarking only
|
||||
python scripts/validate.py benchmark \
|
||||
--retriever_model ./output/bge_m3_三国演义/final_model \
|
||||
--reranker_model ./output/bge_reranker_三国演义/final_model
|
||||
|
||||
# Complete validation suite (includes all above)
|
||||
python scripts/validate.py all \
|
||||
--retriever_model ./output/bge_m3_三国演义/final_model \
|
||||
--reranker_model ./output/bge_reranker_三国演义/final_model \
|
||||
--test_data_dir "data/datasets/三国演义/splits"
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationOrchestrator:
|
||||
"""Orchestrates all validation activities through a single interface"""
|
||||
|
||||
def __init__(self, output_dir: str = "./validation_results"):
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.results = {}
|
||||
|
||||
def run_quick_validation(self, args) -> Dict[str, Any]:
|
||||
"""Run quick validation using the existing quick_validation.py script"""
|
||||
logger.info("🏃♂️ Running Quick Validation (5 minutes)")
|
||||
|
||||
cmd = ["python", "scripts/quick_validation.py"]
|
||||
|
||||
if args.retriever_model:
|
||||
cmd.extend(["--retriever_model", args.retriever_model])
|
||||
if args.reranker_model:
|
||||
cmd.extend(["--reranker_model", args.reranker_model])
|
||||
if args.retriever_model and args.reranker_model:
|
||||
cmd.append("--test_pipeline")
|
||||
|
||||
# Save output to file
|
||||
output_file = self.output_dir / "quick_validation.json"
|
||||
cmd.extend(["--output_file", str(output_file)])
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info("✅ Quick validation completed successfully")
|
||||
|
||||
# Load results if available
|
||||
if output_file.exists():
|
||||
with open(output_file, 'r') as f:
|
||||
return json.load(f)
|
||||
return {"status": "completed", "output": result.stdout}
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"❌ Quick validation failed: {e.stderr}")
|
||||
return {"status": "failed", "error": str(e)}
|
||||
|
||||
def run_comprehensive_validation(self, args) -> Dict[str, Any]:
|
||||
"""Run comprehensive validation"""
|
||||
logger.info("🔬 Running Comprehensive Validation (30 minutes)")
|
||||
|
||||
cmd = ["python", "scripts/comprehensive_validation.py"]
|
||||
|
||||
if args.retriever_model:
|
||||
cmd.extend(["--retriever_finetuned", args.retriever_model])
|
||||
if args.reranker_model:
|
||||
cmd.extend(["--reranker_finetuned", args.reranker_model])
|
||||
|
||||
# Add test datasets
|
||||
if args.test_data_dir:
|
||||
test_dir = Path(args.test_data_dir)
|
||||
test_datasets = []
|
||||
if (test_dir / "m3_test.jsonl").exists():
|
||||
test_datasets.append(str(test_dir / "m3_test.jsonl"))
|
||||
if (test_dir / "reranker_test.jsonl").exists():
|
||||
test_datasets.append(str(test_dir / "reranker_test.jsonl"))
|
||||
|
||||
if test_datasets:
|
||||
cmd.extend(["--test_datasets"] + test_datasets)
|
||||
|
||||
# Configuration
|
||||
cmd.extend([
|
||||
"--output_dir", str(self.output_dir / "comprehensive"),
|
||||
"--batch_size", str(args.batch_size),
|
||||
"--k_values", "1", "3", "5", "10", "20"
|
||||
])
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info("✅ Comprehensive validation completed successfully")
|
||||
|
||||
# Load results if available
|
||||
results_file = self.output_dir / "comprehensive" / "validation_results.json"
|
||||
if results_file.exists():
|
||||
with open(results_file, 'r') as f:
|
||||
return json.load(f)
|
||||
return {"status": "completed", "output": result.stdout}
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"❌ Comprehensive validation failed: {e.stderr}")
|
||||
return {"status": "failed", "error": str(e)}
|
||||
|
||||
def run_comparison(self, args) -> Dict[str, Any]:
|
||||
"""Run model comparison against baselines"""
|
||||
logger.info("⚖️ Running Model Comparison")
|
||||
|
||||
results = {}
|
||||
|
||||
# Compare retriever if available
|
||||
if args.retriever_model and args.retriever_data:
|
||||
cmd = [
|
||||
"python", "scripts/compare_models.py",
|
||||
"--model_type", "retriever",
|
||||
"--finetuned_model", args.retriever_model,
|
||||
"--baseline_model", args.baseline_retriever or "BAAI/bge-m3",
|
||||
"--data_path", args.retriever_data,
|
||||
"--batch_size", str(args.batch_size),
|
||||
"--output_dir", str(self.output_dir / "comparison")
|
||||
]
|
||||
|
||||
if args.max_samples:
|
||||
cmd.extend(["--max_samples", str(args.max_samples)])
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info("✅ Retriever comparison completed")
|
||||
|
||||
# Load results
|
||||
results_file = self.output_dir / "comparison" / "retriever_comparison.json"
|
||||
if results_file.exists():
|
||||
with open(results_file, 'r') as f:
|
||||
results['retriever'] = json.load(f)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"❌ Retriever comparison failed: {e.stderr}")
|
||||
results['retriever'] = {"status": "failed", "error": str(e)}
|
||||
|
||||
# Compare reranker if available
|
||||
if args.reranker_model and args.reranker_data:
|
||||
cmd = [
|
||||
"python", "scripts/compare_models.py",
|
||||
"--model_type", "reranker",
|
||||
"--finetuned_model", args.reranker_model,
|
||||
"--baseline_model", args.baseline_reranker or "BAAI/bge-reranker-base",
|
||||
"--data_path", args.reranker_data,
|
||||
"--batch_size", str(args.batch_size),
|
||||
"--output_dir", str(self.output_dir / "comparison")
|
||||
]
|
||||
|
||||
if args.max_samples:
|
||||
cmd.extend(["--max_samples", str(args.max_samples)])
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info("✅ Reranker comparison completed")
|
||||
|
||||
# Load results
|
||||
results_file = self.output_dir / "comparison" / "reranker_comparison.json"
|
||||
if results_file.exists():
|
||||
with open(results_file, 'r') as f:
|
||||
results['reranker'] = json.load(f)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"❌ Reranker comparison failed: {e.stderr}")
|
||||
results['reranker'] = {"status": "failed", "error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
def run_benchmark(self, args) -> Dict[str, Any]:
|
||||
"""Run performance benchmarking"""
|
||||
logger.info("⚡ Running Performance Benchmarking")
|
||||
|
||||
results = {}
|
||||
|
||||
# Benchmark retriever
|
||||
if args.retriever_model:
|
||||
retriever_data = args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl")
|
||||
if retriever_data:
|
||||
cmd = [
|
||||
"python", "scripts/benchmark.py",
|
||||
"--model_type", "retriever",
|
||||
"--model_path", args.retriever_model,
|
||||
"--data_path", retriever_data,
|
||||
"--batch_size", str(args.batch_size),
|
||||
"--output", str(self.output_dir / "retriever_benchmark.txt")
|
||||
]
|
||||
|
||||
if args.max_samples:
|
||||
cmd.extend(["--max_samples", str(args.max_samples)])
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info("✅ Retriever benchmarking completed")
|
||||
results['retriever'] = {"status": "completed", "output": result.stdout}
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"❌ Retriever benchmarking failed: {e.stderr}")
|
||||
results['retriever'] = {"status": "failed", "error": str(e)}
|
||||
|
||||
# Benchmark reranker
|
||||
if args.reranker_model:
|
||||
reranker_data = args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl")
|
||||
if reranker_data:
|
||||
cmd = [
|
||||
"python", "scripts/benchmark.py",
|
||||
"--model_type", "reranker",
|
||||
"--model_path", args.reranker_model,
|
||||
"--data_path", reranker_data,
|
||||
"--batch_size", str(args.batch_size),
|
||||
"--output", str(self.output_dir / "reranker_benchmark.txt")
|
||||
]
|
||||
|
||||
if args.max_samples:
|
||||
cmd.extend(["--max_samples", str(args.max_samples)])
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, check=True, capture_output=True, text=True)
|
||||
logger.info("✅ Reranker benchmarking completed")
|
||||
results['reranker'] = {"status": "completed", "output": result.stdout}
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"❌ Reranker benchmarking failed: {e.stderr}")
|
||||
results['reranker'] = {"status": "failed", "error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
def run_all_validation(self, args) -> Dict[str, Any]:
|
||||
"""Run complete validation suite"""
|
||||
logger.info("🎯 Running Complete Validation Suite")
|
||||
|
||||
all_results = {
|
||||
'start_time': datetime.now().isoformat(),
|
||||
'validation_modes': []
|
||||
}
|
||||
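# Phases run in a fixed order: quick -> comprehensive -> baseline comparison
# (only when test data is available) -> performance benchmark.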
|
||||
# 1. Quick validation
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("PHASE 1: QUICK VALIDATION")
|
||||
logger.info("="*60)
|
||||
all_results['quick'] = self.run_quick_validation(args)
|
||||
all_results['validation_modes'].append('quick')
|
||||
|
||||
# 2. Comprehensive validation
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("PHASE 2: COMPREHENSIVE VALIDATION")
|
||||
logger.info("="*60)
|
||||
all_results['comprehensive'] = self.run_comprehensive_validation(args)
|
||||
all_results['validation_modes'].append('comprehensive')
|
||||
|
||||
# 3. Comparison validation
|
||||
if self._has_comparison_data(args):
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("PHASE 3: BASELINE COMPARISON")
|
||||
logger.info("="*60)
|
||||
all_results['comparison'] = self.run_comparison(args)
|
||||
all_results['validation_modes'].append('comparison')
|
||||
|
||||
# 4. Performance benchmarking
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("PHASE 4: PERFORMANCE BENCHMARKING")
|
||||
logger.info("="*60)
|
||||
all_results['benchmark'] = self.run_benchmark(args)
|
||||
all_results['validation_modes'].append('benchmark')
|
||||
|
||||
all_results['end_time'] = datetime.now().isoformat()
|
||||
|
||||
# Generate summary report
|
||||
self._generate_summary_report(all_results)
|
||||
|
||||
return all_results
|
||||
|
||||
def _find_test_data(self, test_data_dir: Optional[str], filename: str) -> Optional[str]:
|
||||
"""Find test data file in the specified directory"""
|
||||
if not test_data_dir:
|
||||
return None
|
||||
|
||||
test_path = Path(test_data_dir) / filename
|
||||
return str(test_path) if test_path.exists() else None
|
||||
|
||||
def _has_comparison_data(self, args) -> bool:
|
||||
"""Check if we have data needed for comparison"""
|
||||
has_retriever_data = bool(args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl"))
|
||||
has_reranker_data = bool(args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl"))
|
||||
|
||||
return (args.retriever_model and has_retriever_data) or (args.reranker_model and has_reranker_data)
|
||||
|
||||
def _generate_summary_report(self, results: Dict[str, Any]):
|
||||
"""Generate a comprehensive summary report"""
|
||||
logger.info("\n" + "="*80)
|
||||
logger.info("📋 VALIDATION SUITE SUMMARY REPORT")
|
||||
logger.info("="*80)
|
||||
|
||||
report_lines = []
|
||||
report_lines.append("# BGE Model Validation Report")
|
||||
report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report_lines.append("")
|
||||
|
||||
# Summary
|
||||
completed_phases = len(results['validation_modes'])
|
||||
total_phases = 4
|
||||
success_rate = completed_phases / total_phases * 100
|
||||
|
||||
report_lines.append(f"## Overall Summary")
|
||||
report_lines.append(f"- Validation phases completed: {completed_phases}/{total_phases} ({success_rate:.1f}%)")
|
||||
report_lines.append(f"- Start time: {results['start_time']}")
|
||||
report_lines.append(f"- End time: {results['end_time']}")
|
||||
report_lines.append("")
|
||||
|
||||
# Phase results
|
||||
for phase in ['quick', 'comprehensive', 'comparison', 'benchmark']:
|
||||
if phase in results:
|
||||
phase_result = results[phase]
|
||||
status = "✅ SUCCESS" if phase_result.get('status') != 'failed' else "❌ FAILED"
|
||||
report_lines.append(f"### {phase.title()} Validation: {status}")
|
||||
|
||||
if phase == 'comparison' and isinstance(phase_result, dict):
|
||||
if 'retriever' in phase_result:
|
||||
retriever_summary = phase_result['retriever'].get('summary', {})
|
||||
if 'status' in retriever_summary:
|
||||
report_lines.append(f" - Retriever: {retriever_summary['status']}")
|
||||
if 'reranker' in phase_result:
|
||||
reranker_summary = phase_result['reranker'].get('summary', {})
|
||||
if 'status' in reranker_summary:
|
||||
report_lines.append(f" - Reranker: {reranker_summary['status']}")
|
||||
|
||||
report_lines.append("")
|
||||
|
||||
# Recommendations
|
||||
report_lines.append("## Recommendations")
|
||||
|
||||
if success_rate == 100:
|
||||
report_lines.append("- 🌟 All validation phases completed successfully!")
|
||||
report_lines.append("- Your models are ready for production deployment.")
|
||||
elif success_rate >= 75:
|
||||
report_lines.append("- ✅ Most validation phases successful.")
|
||||
report_lines.append("- Review failed phases and address any issues.")
|
||||
elif success_rate >= 50:
|
||||
report_lines.append("- ⚠️ Partial validation success.")
|
||||
report_lines.append("- Significant issues detected. Review logs carefully.")
|
||||
else:
|
||||
report_lines.append("- ❌ Multiple validation failures.")
|
||||
report_lines.append("- Major issues with model training or configuration.")
|
||||
|
||||
report_lines.append("")
|
||||
report_lines.append("## Files Generated")
|
||||
report_lines.append(f"- Summary report: {self.output_dir}/validation_summary.md")
|
||||
report_lines.append(f"- Detailed results: {self.output_dir}/validation_results.json")
|
||||
if (self.output_dir / "comparison").exists():
|
||||
report_lines.append(f"- Comparison results: {self.output_dir}/comparison/")
|
||||
if (self.output_dir / "comprehensive").exists():
|
||||
report_lines.append(f"- Comprehensive results: {self.output_dir}/comprehensive/")
|
||||
|
||||
# Save report
|
||||
with open(self.output_dir / "validation_summary.md", 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(report_lines))
|
||||
|
||||
with open(self.output_dir / "validation_results.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(results, f, indent=2, default=str)
|
||||
|
||||
# Print summary to console
|
||||
print("\n".join(report_lines))
|
||||
|
||||
logger.info(f"📋 Summary report saved to: {self.output_dir}/validation_summary.md")
|
||||
logger.info(f"📋 Detailed results saved to: {self.output_dir}/validation_results.json")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="BGE Model Validation - Single Entry Point",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__
|
||||
)
|
||||
|
||||
# Validation mode (required)
|
||||
parser.add_argument(
|
||||
"mode",
|
||||
choices=['quick', 'comprehensive', 'compare', 'benchmark', 'all'],
|
||||
help="Validation mode to run"
|
||||
)
|
||||
|
||||
# Model paths
|
||||
parser.add_argument("--retriever_model", type=str, help="Path to fine-tuned retriever model")
|
||||
parser.add_argument("--reranker_model", type=str, help="Path to fine-tuned reranker model")
|
||||
|
||||
# Data paths
|
||||
parser.add_argument("--test_data_dir", type=str, help="Directory containing test datasets")
|
||||
parser.add_argument("--retriever_data", type=str, help="Path to retriever test data")
|
||||
parser.add_argument("--reranker_data", type=str, help="Path to reranker test data")
|
||||
|
||||
# Baseline models for comparison
|
||||
parser.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3",
|
||||
help="Baseline retriever model")
|
||||
parser.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base",
|
||||
help="Baseline reranker model")
|
||||
|
||||
# Configuration
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation")
|
||||
parser.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)")
|
||||
parser.add_argument("--output_dir", type=str, default="./validation_results",
|
||||
help="Directory to save results")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if not args.retriever_model and not args.reranker_model:
|
||||
logger.error("❌ At least one model (--retriever_model or --reranker_model) is required")
|
||||
return 1
|
||||
|
||||
logger.info("🚀 BGE Model Validation Started")
|
||||
logger.info(f"Mode: {args.mode}")
|
||||
logger.info(f"Output directory: {args.output_dir}")
|
||||
|
||||
# Initialize orchestrator
|
||||
orchestrator = ValidationOrchestrator(args.output_dir)
|
||||
|
||||
# Run validation based on mode
|
||||
try:
|
||||
if args.mode == 'quick':
|
||||
results = orchestrator.run_quick_validation(args)
|
||||
elif args.mode == 'comprehensive':
|
||||
results = orchestrator.run_comprehensive_validation(args)
|
||||
elif args.mode == 'compare':
|
||||
results = orchestrator.run_comparison(args)
|
||||
elif args.mode == 'benchmark':
|
||||
results = orchestrator.run_benchmark(args)
|
||||
elif args.mode == 'all':
|
||||
results = orchestrator.run_all_validation(args)
|
||||
else:
|
||||
logger.error(f"Unknown validation mode: {args.mode}")
|
||||
return 1
|
||||
|
||||
logger.info("✅ Validation completed successfully!")
|
||||
logger.info(f"📁 Results saved to: {args.output_dir}")
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Validation failed: {e}")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
@@ -1,122 +0,0 @@
|
||||
"""
|
||||
Quick validation script for the fine-tuned BGE-M3 model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
|
||||
def test_model_embeddings(model_path, test_queries=None):
|
||||
"""Test if the model can generate embeddings"""
|
||||
|
||||
if test_queries is None:
|
||||
test_queries = [
|
||||
"什么是深度学习?", # Original training query
|
||||
"什么是机器学习?", # Related query
|
||||
"今天天气怎么样?" # Unrelated query
|
||||
]
|
||||
|
||||
print(f"🧪 Testing model from: {model_path}")
|
||||
|
||||
try:
|
||||
# Load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGEM3Model.from_pretrained(model_path)
|
||||
model.eval()
|
||||
|
||||
print("✅ Model and tokenizer loaded successfully")
|
||||
|
||||
# Test embeddings generation
|
||||
embeddings = []
|
||||
|
||||
for i, query in enumerate(test_queries):
|
||||
print(f"\n📝 Testing query {i+1}: '{query}'")
|
||||
|
||||
# Tokenize
|
||||
inputs = tokenizer(
|
||||
query,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
with torch.no_grad():
|
||||
outputs = model.forward(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
return_dense=True,
|
||||
return_dict=True
|
||||
)
|
||||
|
||||
if 'dense' in outputs:
|
||||
embedding = outputs['dense'].cpu().numpy() # type: ignore[index]
|
||||
embeddings.append(embedding)
|
||||
|
||||
print(f" ✅ Embedding shape: {embedding.shape}")
|
||||
print(f" 📊 Embedding norm: {np.linalg.norm(embedding):.4f}")
|
||||
print(f" 📈 Mean activation: {embedding.mean():.6f}")
|
||||
print(f" 📉 Std activation: {embedding.std():.6f}")
|
||||
else:
|
||||
print(f" ❌ No dense embedding returned")
|
||||
return False
|
||||
|
||||
# Compare embeddings
|
||||
if len(embeddings) >= 2:
|
||||
print(f"\n🔗 Similarity Analysis:")
|
||||
for i in range(len(embeddings)):
|
||||
for j in range(i+1, len(embeddings)):
|
||||
sim = np.dot(embeddings[i].flatten(), embeddings[j].flatten()) / (
|
||||
np.linalg.norm(embeddings[i]) * np.linalg.norm(embeddings[j])
|
||||
)
|
||||
print(f" Query {i+1} vs Query {j+1}: {sim:.4f}")
|
||||
|
||||
print(f"\n🎉 Model validation successful!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Model validation failed: {e}")
|
||||
return False
|
||||
|
||||
def compare_models(original_path, finetuned_path):
|
||||
"""Compare original vs fine-tuned model"""
|
||||
print("🔄 Comparing original vs fine-tuned model...")
|
||||
|
||||
test_query = "什么是深度学习?"
|
||||
|
||||
# Test original model
|
||||
print("\n📍 Original Model:")
|
||||
orig_success = test_model_embeddings(original_path, [test_query])
|
||||
|
||||
# Test fine-tuned model
|
||||
print("\n📍 Fine-tuned Model:")
|
||||
ft_success = test_model_embeddings(finetuned_path, [test_query])
|
||||
|
||||
return orig_success and ft_success
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 BGE-M3 Model Validation")
|
||||
print("=" * 50)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_model_path = "./test_m3_output/final_model"
|
||||
success = test_model_embeddings(finetuned_model_path)
|
||||
|
||||
# Optional: Compare with original model
|
||||
original_model_path = "./models/bge-m3"
|
||||
compare_models(original_model_path, finetuned_model_path)
|
||||
|
||||
if success:
|
||||
print("\n✅ Fine-tuning validation: PASSED")
|
||||
print(" The model can generate embeddings successfully!")
|
||||
else:
|
||||
print("\n❌ Fine-tuning validation: FAILED")
|
||||
print(" The model has issues generating embeddings!")
|
||||
@@ -1,150 +0,0 @@
|
||||
"""
|
||||
Quick validation script for the fine-tuned BGE Reranker model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
|
||||
def test_reranker_scoring(model_path, test_queries=None):
|
||||
"""Test if the reranker can score query-passage pairs"""
|
||||
|
||||
if test_queries is None:
|
||||
# Test query with relevant and irrelevant passages
|
||||
test_queries = [
|
||||
{
|
||||
"query": "什么是深度学习?",
|
||||
"passages": [
|
||||
"深度学习是机器学习的一个子领域", # Relevant (should score high)
|
||||
"机器学习包含多种算法", # Somewhat relevant
|
||||
"今天天气很好", # Irrelevant (should score low)
|
||||
"深度学习使用神经网络进行学习" # Very relevant
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
print(f"🧪 Testing reranker from: {model_path}")
|
||||
|
||||
try:
|
||||
# Load tokenizer and model
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
model = BGERerankerModel.from_pretrained(model_path)
|
||||
model.eval()
|
||||
|
||||
print("✅ Reranker model and tokenizer loaded successfully")
|
||||
|
||||
for query_idx, test_case in enumerate(test_queries):
|
||||
query = test_case["query"]
|
||||
passages = test_case["passages"]
|
||||
|
||||
print(f"\n📝 Test Case {query_idx + 1}")
|
||||
print(f"Query: '{query}'")
|
||||
print(f"Passages to rank:")
|
||||
|
||||
scores = []
|
||||
|
||||
for i, passage in enumerate(passages):
|
||||
print(f" {i+1}. '{passage}'")
|
||||
|
||||
# Create query-passage pair
|
||||
sep_token = tokenizer.sep_token or '[SEP]'
|
||||
pair_text = f"{query} {sep_token} {passage}"
|
||||
|
||||
# Tokenize
|
||||
inputs = tokenizer(
|
||||
pair_text,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# Get relevance score
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Extract score (logits)
|
||||
if hasattr(outputs, 'logits'):
|
||||
score = outputs.logits.item()
|
||||
elif isinstance(outputs, torch.Tensor):
|
||||
score = outputs.item()
|
||||
else:
|
||||
score = outputs[0].item()
|
||||
|
||||
scores.append((i, passage, score))
|
||||
|
||||
# Sort by relevance score (highest first)
|
||||
scores.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
print(f"\n🏆 Ranking Results:")
|
||||
for rank, (orig_idx, passage, score) in enumerate(scores, 1):
|
||||
print(f" Rank {rank}: Score {score:.4f} - '{passage[:50]}{'...' if len(passage) > 50 else ''}'")
|
||||
|
||||
# Check if most relevant passages ranked highest
|
||||
expected_relevant = ["深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习"]
|
||||
top_passages = [item[1] for item in scores[:2]]
|
||||
|
||||
relevant_in_top = any(rel in top_passages for rel in expected_relevant)
|
||||
irrelevant_in_top = "今天天气很好" in top_passages
|
||||
|
||||
if relevant_in_top and not irrelevant_in_top:
|
||||
print(f" ✅ Good ranking - relevant passages ranked higher")
|
||||
elif relevant_in_top:
|
||||
print(f" ⚠️ Partial ranking - relevant high but irrelevant also high")
|
||||
else:
|
||||
print(f" ❌ Poor ranking - irrelevant passages ranked too high")
|
||||
|
||||
print(f"\n🎉 Reranker validation successful!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Reranker validation failed: {e}")
|
||||
return False
|
||||
|
||||
def compare_rerankers(original_path, finetuned_path):
|
||||
"""Compare original vs fine-tuned reranker"""
|
||||
print("🔄 Comparing original vs fine-tuned reranker...")
|
||||
|
||||
test_case = {
|
||||
"query": "什么是深度学习?",
|
||||
"passages": [
|
||||
"深度学习是机器学习的一个子领域",
|
||||
"今天天气很好"
|
||||
]
|
||||
}
|
||||
|
||||
# Test original model
|
||||
print("\n📍 Original Reranker:")
|
||||
orig_success = test_reranker_scoring(original_path, [test_case])
|
||||
|
||||
# Test fine-tuned model
|
||||
print("\n📍 Fine-tuned Reranker:")
|
||||
ft_success = test_reranker_scoring(finetuned_path, [test_case])
|
||||
|
||||
return orig_success and ft_success
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 BGE Reranker Model Validation")
|
||||
print("=" * 50)
|
||||
|
||||
# Test fine-tuned model
|
||||
finetuned_model_path = "./test_reranker_output/final_model"
|
||||
success = test_reranker_scoring(finetuned_model_path)
|
||||
|
||||
# Optional: Compare with original model
|
||||
original_model_path = "./models/bge-reranker-base"
|
||||
compare_rerankers(original_model_path, finetuned_model_path)
|
||||
|
||||
if success:
|
||||
print("\n✅ Reranker validation: PASSED")
|
||||
print(" The reranker can score query-passage pairs successfully!")
|
||||
else:
|
||||
print("\n❌ Reranker validation: FAILED")
|
||||
print(" The reranker has issues with scoring!")
|
||||