#!/usr/bin/env python3
"""
Comprehensive Validation Script for BGE Fine-tuned Models

This script provides a complete validation framework to determine if fine-tuned models
actually perform better than baseline models across multiple datasets and metrics.

Features:
- Tests both M3 retriever and reranker models
- Compares against baseline models on multiple datasets
- Provides statistical significance testing
- Tests pipeline performance (retrieval + reranking)
- Generates comprehensive reports
- Includes performance degradation detection
- Supports cross-dataset validation

Usage:
    python scripts/comprehensive_validation.py \\
        --retriever_finetuned ./output/bge-m3-enhanced/final_model \\
        --reranker_finetuned ./output/bge-reranker/final_model \\
        --output_dir ./validation_results
"""
import os
import sys
import json
import logging
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from scipy import stats
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import (
    compute_retrieval_metrics, compute_reranker_metrics,
    compute_retrieval_metrics_comprehensive
)
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
from utils.config_loader import get_config


@dataclass
class ValidationConfig:
    """Configuration for comprehensive validation"""
    # Model paths
    retriever_baseline: str = "BAAI/bge-m3"
    reranker_baseline: str = "BAAI/bge-reranker-base"
    retriever_finetuned: Optional[str] = None
    reranker_finetuned: Optional[str] = None
    # Test datasets
    test_datasets: List[str] = None
    # Evaluation settings
    batch_size: int = 32
    max_length: int = 512
    k_values: List[int] = None
    significance_level: float = 0.05
    min_samples_for_significance: int = 30
    # Pipeline settings
    retrieval_top_k: int = 100
    rerank_top_k: int = 10
    # Output settings
    output_dir: str = "./validation_results"
    save_embeddings: bool = False
    create_report: bool = True

    def __post_init__(self):
        if self.test_datasets is None:
            self.test_datasets = [
                "data/datasets/examples/embedding_data.jsonl",
                "data/datasets/examples/reranker_data.jsonl",
                "data/datasets/Reranker_AFQMC/dev.json",
                "data/datasets/三国演义/reranker_train.jsonl"
            ]
        if self.k_values is None:
            self.k_values = [1, 3, 5, 10, 20]


class ComprehensiveValidator:
    """Comprehensive validation framework for BGE models"""

    def __init__(self, config: ValidationConfig):
        self.config = config
        self.logger = get_logger(__name__)
        self.results = {}
        self.device = self._get_device()
        # Create output directory
        os.makedirs(config.output_dir, exist_ok=True)
        # Setup logging
        setup_logging(
            output_dir=config.output_dir,
            log_level=logging.INFO,
            log_to_console=True,
            log_to_file=True
        )

    def _get_device(self) -> torch.device:
        """Get optimal device for evaluation"""
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')
            # NPU selection should not depend on CUDA availability; fall back to
            # CUDA/CPU only if torch_npu is not installed.
            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401
                    device = torch.device("npu:0")
                    self.logger.info("Using NPU device")
                    return device
                except ImportError:
                    pass
            if torch.cuda.is_available():
                device = torch.device("cuda")
                self.logger.info("Using CUDA device")
                return device
        except Exception as e:
            self.logger.warning(f"Error detecting optimal device: {e}")
        device = torch.device("cpu")
        self.logger.info("Using CPU device")
        return device

    def validate_all(self) -> Dict[str, Any]:
        """Run comprehensive validation on all models and datasets"""
        self.logger.info("🚀 Starting Comprehensive BGE Model Validation")
        self.logger.info("=" * 80)
        start_time = datetime.now()
        try:
            # 1. Validate retriever models
            if self.config.retriever_finetuned:
                self.logger.info("📊 Validating Retriever Models...")
                retriever_results = self._validate_retrievers()
                self.results['retriever'] = retriever_results

            # 2. Validate reranker models
            if self.config.reranker_finetuned:
                self.logger.info("📊 Validating Reranker Models...")
                reranker_results = self._validate_rerankers()
                self.results['reranker'] = reranker_results

            # 3. Validate pipeline (if both models available)
            if self.config.retriever_finetuned and self.config.reranker_finetuned:
                self.logger.info("📊 Validating Pipeline Performance...")
                pipeline_results = self._validate_pipeline()
                self.results['pipeline'] = pipeline_results

            # 4. Statistical analysis (run before the report so its results are included)
            self._perform_statistical_analysis()

            # 5. Generate comprehensive report
            if self.config.create_report:
                self._generate_report()

            end_time = datetime.now()
            duration = (end_time - start_time).total_seconds()
            self.logger.info("✅ Comprehensive validation completed!")
            self.logger.info(f"⏱️ Total validation time: {duration:.2f} seconds")
            return self.results
        except Exception as e:
            self.logger.error(f"❌ Validation failed: {e}", exc_info=True)
            raise

    def _validate_retrievers(self) -> Dict[str, Any]:
        """Validate retriever models on all suitable datasets"""
        results = {
            'baseline_metrics': {},
            'finetuned_metrics': {},
            'comparisons': {},
            'datasets_tested': []
        }
        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                self.logger.warning(f"Dataset not found: {dataset_path}")
                continue
            # Skip reranker-only datasets for retriever testing
            if 'reranker' in dataset_path.lower() and 'AFQMC' in dataset_path:
                continue
            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing retriever on dataset: {dataset_name}")
            try:
                # Test baseline model
                baseline_metrics = self._test_retriever_on_dataset(
                    model_path=self.config.retriever_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )
                # Test fine-tuned model
                finetuned_metrics = self._test_retriever_on_dataset(
                    model_path=self.config.retriever_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )
                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "retriever")
                results['baseline_metrics'][dataset_name] = baseline_metrics
                results['finetuned_metrics'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)
                self.logger.info(f"✅ Retriever validation complete for {dataset_name}")
            except Exception as e:
                self.logger.error(f"❌ Failed to test retriever on {dataset_name}: {e}")
                continue
        return results

    def _validate_rerankers(self) -> Dict[str, Any]:
        """Validate reranker models on all suitable datasets"""
        results = {
            'baseline_metrics': {},
            'finetuned_metrics': {},
            'comparisons': {},
            'datasets_tested': []
        }
        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                self.logger.warning(f"Dataset not found: {dataset_path}")
                continue
            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing reranker on dataset: {dataset_name}")
            try:
                # Test baseline model
                baseline_metrics = self._test_reranker_on_dataset(
                    model_path=self.config.reranker_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )
                # Test fine-tuned model
                finetuned_metrics = self._test_reranker_on_dataset(
                    model_path=self.config.reranker_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )
                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "reranker")
                results['baseline_metrics'][dataset_name] = baseline_metrics
                results['finetuned_metrics'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)
                self.logger.info(f"✅ Reranker validation complete for {dataset_name}")
            except Exception as e:
                self.logger.error(f"❌ Failed to test reranker on {dataset_name}: {e}")
                continue
        return results

    def _validate_pipeline(self) -> Dict[str, Any]:
        """Validate end-to-end retrieval + reranking pipeline"""
        results = {
            'baseline_pipeline': {},
            'finetuned_pipeline': {},
            'comparisons': {},
            'datasets_tested': []
        }
        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                continue
            # Only test on datasets suitable for pipeline evaluation
            if 'embedding_data' not in dataset_path:
                continue
            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing pipeline on dataset: {dataset_name}")
            try:
                # Test baseline pipeline
                baseline_metrics = self._test_pipeline_on_dataset(
                    retriever_path=self.config.retriever_baseline,
                    reranker_path=self.config.reranker_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )
                # Test fine-tuned pipeline
                finetuned_metrics = self._test_pipeline_on_dataset(
                    retriever_path=self.config.retriever_finetuned,
                    reranker_path=self.config.reranker_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )
                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "pipeline")
                results['baseline_pipeline'][dataset_name] = baseline_metrics
                results['finetuned_pipeline'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)
                self.logger.info(f"✅ Pipeline validation complete for {dataset_name}")
            except Exception as e:
                self.logger.error(f"❌ Failed to test pipeline on {dataset_name}: {e}")
                continue
        return results

    def _test_retriever_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
        """Test a retriever model on a specific dataset"""
        try:
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGEM3Model.from_pretrained(model_path)
            model.to(self.device)
            model.eval()
            # Create dataset
            dataset = BGEM3Dataset(
                data_path=dataset_path,
                tokenizer=tokenizer,
                query_max_length=self.config.max_length,
                passage_max_length=self.config.max_length,
                is_train=False
            )
            dataloader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=False
            )
            # Create evaluator
            evaluator = RetrievalEvaluator(
                model=model,
                tokenizer=tokenizer,
                device=self.device,
                k_values=self.config.k_values
            )
            # Run evaluation
            metrics = evaluator.evaluate(
                dataloader=dataloader,
                desc=f"{'Baseline' if is_baseline else 'Fine-tuned'} Retriever"
            )
            return metrics
        except Exception as e:
            self.logger.error(f"Error testing retriever: {e}")
            return {}

    def _test_reranker_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
        """Test a reranker model on a specific dataset"""
        try:
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGERerankerModel.from_pretrained(model_path)
            model.to(self.device)
            model.eval()
            # Create dataset
            dataset = BGERerankerDataset(
                data_path=dataset_path,
                tokenizer=tokenizer,
                query_max_length=self.config.max_length,
                passage_max_length=self.config.max_length,
                is_train=False
            )
            dataloader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=False
            )
            # Collect predictions and labels
            all_scores = []
            all_labels = []
            with torch.no_grad():
                for batch in dataloader:
                    # Move to device
                    input_ids = batch['input_ids'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)
                    labels = batch['labels'].cpu().numpy()
                    # Get predictions
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    scores = outputs.logits.cpu().numpy()
                    all_scores.append(scores)
                    all_labels.append(labels)
            # Concatenate results
            all_scores = np.concatenate(all_scores, axis=0)
            all_labels = np.concatenate(all_labels, axis=0)
            # Compute metrics
            metrics = compute_reranker_metrics(
                scores=all_scores,
                labels=all_labels,
                k_values=self.config.k_values
            )
            return metrics
        except Exception as e:
            self.logger.error(f"Error testing reranker: {e}")
            return {}

    def _test_pipeline_on_dataset(self, retriever_path: str, reranker_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
        """Test end-to-end pipeline on a dataset"""
        try:
            # Load models
            retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
            retriever = BGEM3Model.from_pretrained(retriever_path)
            retriever.to(self.device)
            retriever.eval()
            reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
            reranker = BGERerankerModel.from_pretrained(reranker_path)
            reranker.to(self.device)
            reranker.eval()
            # Load and preprocess data
            data = []
            with open(dataset_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        item = json.loads(line.strip())
                        if 'query' in item and 'pos' in item and 'neg' in item:
                            data.append(item)
                    except json.JSONDecodeError:
                        continue
            if not data:
                return {}
            # Test pipeline performance
            total_queries = 0
            pipeline_metrics = {f'recall@{k}': [] for k in self.config.k_values}
            pipeline_metrics.update({f'mrr@{k}': [] for k in self.config.k_values})
            for item in data[:100]:  # Limit to 100 queries for speed
                query = item['query']
                pos_docs = item['pos'][:2]  # Take first 2 positive docs
                neg_docs = item['neg'][:8]  # Take first 8 negative docs
                # Create candidate pool
                candidates = pos_docs + neg_docs
                labels = [1] * len(pos_docs) + [0] * len(neg_docs)
                if len(candidates) < 3:  # Need minimum candidates
                    continue
                try:
                    # Step 1: Retrieval (encode query and candidates)
                    query_inputs = retriever_tokenizer(
                        query,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=self.config.max_length
                    ).to(self.device)
                    with torch.no_grad():
                        query_outputs = retriever(**query_inputs, return_dense=True)
                        query_emb = query_outputs['dense']
                    # Encode candidates
                    candidate_embs = []
                    for candidate in candidates:
                        cand_inputs = retriever_tokenizer(
                            candidate,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=self.config.max_length
                        ).to(self.device)
                        with torch.no_grad():
                            cand_outputs = retriever(**cand_inputs, return_dense=True)
                            candidate_embs.append(cand_outputs['dense'])
                    candidate_embs = torch.cat(candidate_embs, dim=0)
                    # Compute similarities
                    similarities = torch.cosine_similarity(query_emb, candidate_embs, dim=1)
                    # Get top-k for retrieval
                    top_k_retrieval = min(self.config.retrieval_top_k, len(candidates))
                    retrieval_indices = torch.argsort(similarities, descending=True)[:top_k_retrieval]
                    # Step 2: Reranking
                    rerank_candidates = [candidates[i] for i in retrieval_indices]
                    rerank_labels = [labels[i] for i in retrieval_indices]
                    # Score pairs with reranker
                    rerank_scores = []
                    for candidate in rerank_candidates:
                        pair_text = f"{query} [SEP] {candidate}"
                        pair_inputs = reranker_tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=self.config.max_length
                        ).to(self.device)
                        with torch.no_grad():
                            pair_outputs = reranker(**pair_inputs)
                            score = pair_outputs.logits.item()
                        rerank_scores.append(score)
                    # Final ranking
                    final_indices = np.argsort(rerank_scores)[::-1]
                    final_labels = [rerank_labels[i] for i in final_indices]
                    # Compute metrics
                    for k in self.config.k_values:
                        if k <= len(final_labels):
                            # Recall@k
                            relevant_in_topk = sum(final_labels[:k])
                            total_relevant = sum(rerank_labels)
                            recall = relevant_in_topk / total_relevant if total_relevant > 0 else 0
                            pipeline_metrics[f'recall@{k}'].append(recall)
                            # MRR@k
                            for i, label in enumerate(final_labels[:k]):
                                if label == 1:
                                    pipeline_metrics[f'mrr@{k}'].append(1.0 / (i + 1))
                                    break
                            else:
                                pipeline_metrics[f'mrr@{k}'].append(0.0)
                    total_queries += 1
                except Exception as e:
                    self.logger.warning(f"Error processing query: {e}")
                    continue
            # Average metrics
            final_metrics = {}
            for metric, values in pipeline_metrics.items():
                if values:
                    final_metrics[metric] = np.mean(values)
                else:
                    final_metrics[metric] = 0.0
            final_metrics['total_queries'] = total_queries
            return final_metrics
        except Exception as e:
            self.logger.error(f"Error testing pipeline: {e}")
            return {}

    def _compare_metrics(self, baseline: Dict[str, float], finetuned: Dict[str, float], model_type: str) -> Dict[str, Any]:
        """Compare baseline vs fine-tuned metrics with statistical analysis"""
        comparison = {
            'improvements': {},
            'degradations': {},
            'statistical_significance': {},
            'summary': {}
        }
        # Compare common metrics
        common_metrics = set(baseline.keys()) & set(finetuned.keys())
        improvements = 0
        degradations = 0
        significant_improvements = 0
        for metric in common_metrics:
            if isinstance(baseline[metric], (int, float)) and isinstance(finetuned[metric], (int, float)):
                baseline_val = baseline[metric]
                finetuned_val = finetuned[metric]
                # Calculate improvement
                if baseline_val != 0:
                    improvement_pct = ((finetuned_val - baseline_val) / baseline_val) * 100
                else:
                    improvement_pct = float('inf') if finetuned_val > 0 else 0
                # Classify improvement/degradation
                if finetuned_val > baseline_val:
                    comparison['improvements'][metric] = {
                        'baseline': baseline_val,
                        'finetuned': finetuned_val,
                        'improvement_pct': improvement_pct,
                        'absolute_improvement': finetuned_val - baseline_val
                    }
                    improvements += 1
                elif finetuned_val < baseline_val:
                    comparison['degradations'][metric] = {
                        'baseline': baseline_val,
                        'finetuned': finetuned_val,
                        'degradation_pct': -improvement_pct,
                        'absolute_degradation': baseline_val - finetuned_val
                    }
                    degradations += 1
        # Summary statistics
        comparison['summary'] = {
            'total_metrics': len(common_metrics),
            'improvements': improvements,
            'degradations': degradations,
            'unchanged': len(common_metrics) - improvements - degradations,
            'improvement_ratio': improvements / len(common_metrics) if common_metrics else 0,
            'net_improvement': improvements - degradations
        }
        return comparison

    def _perform_statistical_analysis(self):
        """Perform statistical analysis on validation results"""
        self.logger.info("📈 Performing Statistical Analysis...")
        analysis = {
            'overall_summary': {},
            'significance_tests': {},
            'confidence_intervals': {},
            'effect_sizes': {}
        }
        # Analyze each model type
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type not in self.results:
                continue
            model_results = self.results[model_type]
            # Collect all improvements across datasets
            all_improvements = []
            significant_improvements = 0
            for dataset, comparison in model_results.get('comparisons', {}).items():
                for metric, improvement_data in comparison.get('improvements', {}).items():
                    improvement_pct = improvement_data.get('improvement_pct', 0)
                    if improvement_pct > 0:
                        all_improvements.append(improvement_pct)
                        # Simple significance test (can be enhanced)
                        if improvement_pct > 5.0:  # 5% improvement threshold
                            significant_improvements += 1
            if all_improvements:
                analysis[model_type] = {
                    'mean_improvement_pct': np.mean(all_improvements),
                    'median_improvement_pct': np.median(all_improvements),
                    'std_improvement_pct': np.std(all_improvements),
                    'min_improvement_pct': np.min(all_improvements),
                    'max_improvement_pct': np.max(all_improvements),
                    'significant_improvements': significant_improvements,
                    'total_improvements': len(all_improvements)
                }
        self.results['statistical_analysis'] = analysis
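
    # Note: the module imports scipy.stats and ValidationConfig exposes
    # significance_level / min_samples_for_significance, but the analysis above
    # only applies a fixed 5% threshold. Below is a minimal illustrative sketch
    # of a proper paired test; it is an assumption (not called by validate_all())
    # and expects matched per-sample metric values for baseline and fine-tuned runs.
    @staticmethod
    def _paired_significance_sketch(baseline_values: List[float],
                                    finetuned_values: List[float],
                                    alpha: float = 0.05) -> Dict[str, Any]:
        """Paired t-test on matched baseline/fine-tuned metric samples."""
        t_stat, p_value = stats.ttest_rel(finetuned_values, baseline_values)
        return {
            't_statistic': float(t_stat),
            'p_value': float(p_value),
            'significant': bool(p_value < alpha),
        }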

    def _generate_report(self):
        """Generate comprehensive validation report"""
        self.logger.info("📝 Generating Comprehensive Report...")
        report_path = os.path.join(self.config.output_dir, "validation_report.html")
        html_content = self._create_html_report()
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
        # Also save JSON results
        json_path = os.path.join(self.config.output_dir, "validation_results.json")
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, default=str)
        self.logger.info(f"📊 Report saved to: {report_path}")
        self.logger.info(f"📊 JSON results saved to: {json_path}")

    def _create_html_report(self) -> str:
        """Create HTML report with comprehensive validation results"""
        html = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>BGE Model Validation Report</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                .header {{ background: #f0f8ff; padding: 20px; border-radius: 10px; }}
                .section {{ margin: 20px 0; }}
                .metrics-table {{ border-collapse: collapse; width: 100%; }}
                .metrics-table th, .metrics-table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
                .metrics-table th {{ background-color: #f2f2f2; }}
                .improvement {{ color: green; font-weight: bold; }}
                .degradation {{ color: red; font-weight: bold; }}
                .summary-box {{ background: #f9f9f9; padding: 15px; border-left: 4px solid #4CAF50; margin: 10px 0; }}
            </style>
        </head>
        <body>
            <div class="header">
                <h1>🚀 BGE Model Comprehensive Validation Report</h1>
                <p><strong>Generated:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
                <p><strong>Configuration:</strong></p>
                <ul>
                    <li>Retriever Baseline: {self.config.retriever_baseline}</li>
                    <li>Retriever Fine-tuned: {self.config.retriever_finetuned or 'Not tested'}</li>
                    <li>Reranker Baseline: {self.config.reranker_baseline}</li>
                    <li>Reranker Fine-tuned: {self.config.reranker_finetuned or 'Not tested'}</li>
                </ul>
            </div>
        """
        # Add executive summary
        html += self._create_executive_summary()
        # Add detailed results for each model type
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type in self.results:
                html += self._create_model_section(model_type)
        html += """
        </body>
        </html>
        """
        return html

    def _create_executive_summary(self) -> str:
        """Create executive summary section"""
        html = """
        <div class="section">
            <h2>📊 Executive Summary</h2>
        """
        # Overall verdict
        total_improvements = 0
        total_degradations = 0
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type in self.results and 'comparisons' in self.results[model_type]:
                for dataset, comparison in self.results[model_type]['comparisons'].items():
                    total_improvements += len(comparison.get('improvements', {}))
                    total_degradations += len(comparison.get('degradations', {}))
        if total_improvements > total_degradations:
            verdict = "✅ <strong>FINE-TUNING SUCCESSFUL</strong> - Models show overall improvement"
            verdict_color = "green"
        elif total_improvements == total_degradations:
            verdict = "⚠️ <strong>MIXED RESULTS</strong> - Equal improvements and degradations"
            verdict_color = "orange"
        else:
            verdict = "❌ <strong>FINE-TUNING ISSUES</strong> - More degradations than improvements"
            verdict_color = "red"
        html += f"""
        <div class="summary-box" style="border-left-color: {verdict_color};">
            <h3>{verdict}</h3>
            <ul>
                <li>Total Metric Improvements: {total_improvements}</li>
                <li>Total Metric Degradations: {total_degradations}</li>
                <li>Net Improvement: {total_improvements - total_degradations}</li>
            </ul>
        </div>
        """
        html += """
        </div>
        """
        return html

    def _create_model_section(self, model_type: str) -> str:
        """Create detailed section for a model type"""
        html = f"""
        <div class="section">
            <h2>📈 {model_type.title()} Model Results</h2>
        """
        model_results = self.results[model_type]
        for dataset in model_results.get('datasets_tested', []):
            if dataset in model_results['comparisons']:
                comparison = model_results['comparisons'][dataset]
                html += f"""
                <h3>Dataset: {dataset}</h3>
                <table class="metrics-table">
                    <thead>
                        <tr>
                            <th>Metric</th>
                            <th>Baseline</th>
                            <th>Fine-tuned</th>
                            <th>Change</th>
                            <th>Status</th>
                        </tr>
                    </thead>
                    <tbody>
                """
                # Add improvements
                for metric, data in comparison.get('improvements', {}).items():
                    html += f"""
                        <tr>
                            <td>{metric}</td>
                            <td>{data['baseline']:.4f}</td>
                            <td>{data['finetuned']:.4f}</td>
                            <td class="improvement">+{data['improvement_pct']:.2f}%</td>
                            <td class="improvement">✅ Improved</td>
                        </tr>
                    """
                # Add degradations
                for metric, data in comparison.get('degradations', {}).items():
                    html += f"""
                        <tr>
                            <td>{metric}</td>
                            <td>{data['baseline']:.4f}</td>
                            <td>{data['finetuned']:.4f}</td>
                            <td class="degradation">-{data['degradation_pct']:.2f}%</td>
                            <td class="degradation">❌ Degraded</td>
                        </tr>
                    """
                html += """
                    </tbody>
                </table>
                """
        html += """
        </div>
        """
        return html


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Comprehensive validation of BGE fine-tuned models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Validate both retriever and reranker
  python scripts/comprehensive_validation.py \\
      --retriever_finetuned ./output/bge-m3-enhanced/final_model \\
      --reranker_finetuned ./output/bge-reranker/final_model

  # Validate only retriever
  python scripts/comprehensive_validation.py \\
      --retriever_finetuned ./output/bge-m3-enhanced/final_model

  # Custom datasets and output
  python scripts/comprehensive_validation.py \\
      --retriever_finetuned ./output/bge-m3-enhanced/final_model \\
      --test_datasets data/my_test1.jsonl data/my_test2.jsonl \\
      --output_dir ./my_validation_results
        """
    )
    # Model paths
    parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
                        help="Path to baseline retriever model")
    parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
                        help="Path to baseline reranker model")
    parser.add_argument("--retriever_finetuned", type=str, default=None,
                        help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_finetuned", type=str, default=None,
                        help="Path to fine-tuned reranker model")
    # Test data
    parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
                        help="Paths to test datasets (JSONL format)")
    # Evaluation settings
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for evaluation")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Maximum sequence length")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 3, 5, 10, 20],
                        help="K values for evaluation metrics")
    # Output settings
    parser.add_argument("--output_dir", type=str, default="./validation_results",
                        help="Directory to save validation results")
    parser.add_argument("--save_embeddings", action="store_true",
                        help="Save embeddings for further analysis")
    parser.add_argument("--no_report", action="store_true",
                        help="Skip HTML report generation")
    # Pipeline settings
    parser.add_argument("--retrieval_top_k", type=int, default=100,
                        help="Top-k for retrieval stage")
    parser.add_argument("--rerank_top_k", type=int, default=10,
                        help="Top-k for reranking stage")
    return parser.parse_args()


def main():
    """Main validation function"""
    args = parse_args()
    # Validate arguments
    if not args.retriever_finetuned and not args.reranker_finetuned:
        print("❌ Error: At least one fine-tuned model must be specified")
        print("   Use --retriever_finetuned or --reranker_finetuned")
        return 1
    # Create configuration
    config = ValidationConfig(
        retriever_baseline=args.retriever_baseline,
        reranker_baseline=args.reranker_baseline,
        retriever_finetuned=args.retriever_finetuned,
        reranker_finetuned=args.reranker_finetuned,
        test_datasets=args.test_datasets,
        batch_size=args.batch_size,
        max_length=args.max_length,
        k_values=args.k_values,
        output_dir=args.output_dir,
        save_embeddings=args.save_embeddings,
        create_report=not args.no_report,
        retrieval_top_k=args.retrieval_top_k,
        rerank_top_k=args.rerank_top_k
    )
    # Run validation
    try:
        validator = ComprehensiveValidator(config)
        results = validator.validate_all()
        print("\n" + "=" * 80)
        print("🎉 COMPREHENSIVE VALIDATION COMPLETED!")
        print("=" * 80)
        print(f"📊 Results saved to: {config.output_dir}")
        print("\n📈 Quick Summary:")
        # Print quick summary
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type in results:
                model_results = results[model_type]
                total_datasets = len(model_results.get('datasets_tested', []))
                total_improvements = sum(len(comp.get('improvements', {}))
                                         for comp in model_results.get('comparisons', {}).values())
                total_degradations = sum(len(comp.get('degradations', {}))
                                         for comp in model_results.get('comparisons', {}).values())
                status = "✅" if total_improvements > total_degradations else "⚠️" if total_improvements == total_degradations else "❌"
                print(f"  {status} {model_type.title()}: {total_improvements} improvements, "
                      f"{total_degradations} degradations across {total_datasets} datasets")
        return 0
    except KeyboardInterrupt:
        print("\n❌ Validation interrupted by user")
        return 1
    except Exception as e:
        print(f"\n❌ Validation failed: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())