#!/usr/bin/env python3
"""
BGE Model Validation - Single Entry Point

This script provides a unified, streamlined interface for all BGE model validation needs.
It consolidates the previously separate (and confusing) validation scripts into one clear,
easy-to-use tool.

Usage Examples:

    # Quick validation (5 minutes)
    python scripts/validate.py quick \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model

    # Comprehensive validation (30 minutes)
    python scripts/validate.py comprehensive \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --test_data_dir "data/datasets/三国演义/splits"

    # Compare against baselines
    python scripts/validate.py compare \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --retriever_data "data/datasets/三国演义/splits/m3_test.jsonl" \
        --reranker_data "data/datasets/三国演义/splits/reranker_test.jsonl"

    # Performance benchmarking only
    python scripts/validate.py benchmark \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --test_data_dir "data/datasets/三国演义/splits"

    # Complete validation suite (runs all of the above)
    python scripts/validate.py all \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --test_data_dir "data/datasets/三国演义/splits"
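
    # Optional tuning flags for any mode (flag values below are illustrative only;
    # defaults are defined in parse_args)
    python scripts/validate.py all \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --test_data_dir "data/datasets/三国演义/splits" \
        --batch_size 16 --max_samples 500 --output_dir ./validation_results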
"""

import os
import sys
import argparse
import logging
import json
import subprocess
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Any

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ValidationOrchestrator:
    """Orchestrates all validation activities through a single interface"""

    def __init__(self, output_dir: str = "./validation_results"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.results = {}

    def run_quick_validation(self, args) -> Dict[str, Any]:
        """Run quick validation using the existing quick_validation.py script"""
        logger.info("🏃‍♂️ Running Quick Validation (5 minutes)")

        cmd = ["python", "scripts/quick_validation.py"]

        if args.retriever_model:
            cmd.extend(["--retriever_model", args.retriever_model])
        if args.reranker_model:
            cmd.extend(["--reranker_model", args.reranker_model])
        if args.retriever_model and args.reranker_model:
            cmd.append("--test_pipeline")

        # Save output to file
        output_file = self.output_dir / "quick_validation.json"
        cmd.extend(["--output_file", str(output_file)])

        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
            logger.info("✅ Quick validation completed successfully")

            # Load results if available
            if output_file.exists():
                with open(output_file, 'r') as f:
                    return json.load(f)
            return {"status": "completed", "output": result.stdout}

        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Quick validation failed: {e.stderr}")
            return {"status": "failed", "error": str(e)}

    def run_comprehensive_validation(self, args) -> Dict[str, Any]:
        """Run comprehensive validation"""
        logger.info("🔬 Running Comprehensive Validation (30 minutes)")

        cmd = ["python", "scripts/comprehensive_validation.py"]

        if args.retriever_model:
            cmd.extend(["--retriever_finetuned", args.retriever_model])
        if args.reranker_model:
            cmd.extend(["--reranker_finetuned", args.reranker_model])

        # Add test datasets
        if args.test_data_dir:
            test_dir = Path(args.test_data_dir)
            test_datasets = []
            if (test_dir / "m3_test.jsonl").exists():
                test_datasets.append(str(test_dir / "m3_test.jsonl"))
            if (test_dir / "reranker_test.jsonl").exists():
                test_datasets.append(str(test_dir / "reranker_test.jsonl"))

            if test_datasets:
                cmd.extend(["--test_datasets"] + test_datasets)

        # Configuration
        cmd.extend([
            "--output_dir", str(self.output_dir / "comprehensive"),
            "--batch_size", str(args.batch_size),
            "--k_values", "1", "3", "5", "10", "20"
        ])

        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
            logger.info("✅ Comprehensive validation completed successfully")

            # Load results if available
            results_file = self.output_dir / "comprehensive" / "validation_results.json"
            if results_file.exists():
                with open(results_file, 'r') as f:
                    return json.load(f)
            return {"status": "completed", "output": result.stdout}

        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Comprehensive validation failed: {e.stderr}")
            return {"status": "failed", "error": str(e)}

    def run_comparison(self, args) -> Dict[str, Any]:
        """Run model comparison against baselines"""
        logger.info("⚖️ Running Model Comparison")

        results = {}

        # Use explicit data paths when given, otherwise fall back to the standard
        # files under test_data_dir (mirrors run_benchmark and _has_comparison_data)
        retriever_data = args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl")
        reranker_data = args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl")

        # Compare retriever if available
        if args.retriever_model and retriever_data:
            cmd = [
                "python", "scripts/compare_models.py",
                "--model_type", "retriever",
                "--finetuned_model", args.retriever_model,
                "--baseline_model", args.baseline_retriever or "BAAI/bge-m3",
                "--data_path", retriever_data,
                "--batch_size", str(args.batch_size),
                "--output_dir", str(self.output_dir / "comparison")
            ]

            if args.max_samples:
                cmd.extend(["--max_samples", str(args.max_samples)])

            try:
                subprocess.run(cmd, check=True, capture_output=True, text=True)
                logger.info("✅ Retriever comparison completed")

                # Load results
                results_file = self.output_dir / "comparison" / "retriever_comparison.json"
                if results_file.exists():
                    with open(results_file, 'r') as f:
                        results['retriever'] = json.load(f)

            except subprocess.CalledProcessError as e:
                logger.error(f"❌ Retriever comparison failed: {e.stderr}")
                results['retriever'] = {"status": "failed", "error": str(e)}

        # Compare reranker if available
        if args.reranker_model and reranker_data:
            cmd = [
                "python", "scripts/compare_models.py",
                "--model_type", "reranker",
                "--finetuned_model", args.reranker_model,
                "--baseline_model", args.baseline_reranker or "BAAI/bge-reranker-base",
                "--data_path", reranker_data,
                "--batch_size", str(args.batch_size),
                "--output_dir", str(self.output_dir / "comparison")
            ]

            if args.max_samples:
                cmd.extend(["--max_samples", str(args.max_samples)])

            try:
                subprocess.run(cmd, check=True, capture_output=True, text=True)
                logger.info("✅ Reranker comparison completed")

                # Load results
                results_file = self.output_dir / "comparison" / "reranker_comparison.json"
                if results_file.exists():
                    with open(results_file, 'r') as f:
                        results['reranker'] = json.load(f)

            except subprocess.CalledProcessError as e:
                logger.error(f"❌ Reranker comparison failed: {e.stderr}")
                results['reranker'] = {"status": "failed", "error": str(e)}

        return results

    def run_benchmark(self, args) -> Dict[str, Any]:
        """Run performance benchmarking"""
        logger.info("⚡ Running Performance Benchmarking")

        results = {}

        # Benchmark retriever
        if args.retriever_model:
            retriever_data = args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl")
            if retriever_data:
                cmd = [
                    "python", "scripts/benchmark.py",
                    "--model_type", "retriever",
                    "--model_path", args.retriever_model,
                    "--data_path", retriever_data,
                    "--batch_size", str(args.batch_size),
                    "--output", str(self.output_dir / "retriever_benchmark.txt")
                ]

                if args.max_samples:
                    cmd.extend(["--max_samples", str(args.max_samples)])

                try:
                    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
                    logger.info("✅ Retriever benchmarking completed")
                    results['retriever'] = {"status": "completed", "output": result.stdout}
                except subprocess.CalledProcessError as e:
                    logger.error(f"❌ Retriever benchmarking failed: {e.stderr}")
                    results['retriever'] = {"status": "failed", "error": str(e)}

        # Benchmark reranker
        if args.reranker_model:
            reranker_data = args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl")
            if reranker_data:
                cmd = [
                    "python", "scripts/benchmark.py",
                    "--model_type", "reranker",
                    "--model_path", args.reranker_model,
                    "--data_path", reranker_data,
                    "--batch_size", str(args.batch_size),
                    "--output", str(self.output_dir / "reranker_benchmark.txt")
                ]

                if args.max_samples:
                    cmd.extend(["--max_samples", str(args.max_samples)])

                try:
                    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
                    logger.info("✅ Reranker benchmarking completed")
                    results['reranker'] = {"status": "completed", "output": result.stdout}
                except subprocess.CalledProcessError as e:
                    logger.error(f"❌ Reranker benchmarking failed: {e.stderr}")
                    results['reranker'] = {"status": "failed", "error": str(e)}

        return results

    def run_all_validation(self, args) -> Dict[str, Any]:
        """Run complete validation suite"""
        logger.info("🎯 Running Complete Validation Suite")

        all_results = {
            'start_time': datetime.now().isoformat(),
            'validation_modes': []
        }

        # 1. Quick validation
        logger.info("\n" + "="*60)
        logger.info("PHASE 1: QUICK VALIDATION")
        logger.info("="*60)
        all_results['quick'] = self.run_quick_validation(args)
        all_results['validation_modes'].append('quick')

        # 2. Comprehensive validation
        logger.info("\n" + "="*60)
        logger.info("PHASE 2: COMPREHENSIVE VALIDATION")
        logger.info("="*60)
        all_results['comprehensive'] = self.run_comprehensive_validation(args)
        all_results['validation_modes'].append('comprehensive')

        # 3. Comparison validation
        if self._has_comparison_data(args):
            logger.info("\n" + "="*60)
            logger.info("PHASE 3: BASELINE COMPARISON")
            logger.info("="*60)
            all_results['comparison'] = self.run_comparison(args)
            all_results['validation_modes'].append('comparison')

        # 4. Performance benchmarking
        logger.info("\n" + "="*60)
        logger.info("PHASE 4: PERFORMANCE BENCHMARKING")
        logger.info("="*60)
        all_results['benchmark'] = self.run_benchmark(args)
        all_results['validation_modes'].append('benchmark')

        all_results['end_time'] = datetime.now().isoformat()

        # Generate summary report
        self._generate_summary_report(all_results)

        return all_results

    def _find_test_data(self, test_data_dir: Optional[str], filename: str) -> Optional[str]:
        """Find a test data file in the specified directory"""
        if not test_data_dir:
            return None

        test_path = Path(test_data_dir) / filename
        return str(test_path) if test_path.exists() else None

    def _has_comparison_data(self, args) -> bool:
        """Check whether we have the data needed for a baseline comparison"""
        has_retriever_data = bool(args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl"))
        has_reranker_data = bool(args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl"))

        return bool((args.retriever_model and has_retriever_data) or (args.reranker_model and has_reranker_data))

    def _generate_summary_report(self, results: Dict[str, Any]):
        """Generate a comprehensive summary report"""
        logger.info("\n" + "="*80)
        logger.info("📋 VALIDATION SUITE SUMMARY REPORT")
        logger.info("="*80)

        report_lines = []
        report_lines.append("# BGE Model Validation Report")
        report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report_lines.append("")

        # Summary (note: completed_phases counts phases that were run, whether or not
        # they succeeded; per-phase status is reported below)
        completed_phases = len(results['validation_modes'])
        total_phases = 4
        success_rate = completed_phases / total_phases * 100

        report_lines.append("## Overall Summary")
        report_lines.append(f"- Validation phases completed: {completed_phases}/{total_phases} ({success_rate:.1f}%)")
        report_lines.append(f"- Start time: {results['start_time']}")
        report_lines.append(f"- End time: {results['end_time']}")
        report_lines.append("")

        # Phase results. Phase-level status reflects the top-level 'status' key;
        # per-model results from the comparison and benchmark phases are preserved
        # in the detailed JSON output.
        for phase in ['quick', 'comprehensive', 'comparison', 'benchmark']:
            if phase in results:
                phase_result = results[phase]
                status = "✅ SUCCESS" if phase_result.get('status') != 'failed' else "❌ FAILED"
                report_lines.append(f"### {phase.title()} Validation: {status}")

                if phase == 'comparison' and isinstance(phase_result, dict):
                    if 'retriever' in phase_result:
                        retriever_summary = phase_result['retriever'].get('summary', {})
                        if 'status' in retriever_summary:
                            report_lines.append(f"  - Retriever: {retriever_summary['status']}")
                    if 'reranker' in phase_result:
                        reranker_summary = phase_result['reranker'].get('summary', {})
                        if 'status' in reranker_summary:
                            report_lines.append(f"  - Reranker: {reranker_summary['status']}")

                report_lines.append("")

        # Recommendations
        report_lines.append("## Recommendations")

        if success_rate == 100:
            report_lines.append("- 🌟 All validation phases completed successfully!")
            report_lines.append("- Your models are ready for production deployment.")
        elif success_rate >= 75:
            report_lines.append("- ✅ Most validation phases succeeded.")
            report_lines.append("- Review the failed phases and address any issues.")
        elif success_rate >= 50:
            report_lines.append("- ⚠️ Partial validation success.")
            report_lines.append("- Significant issues detected. Review the logs carefully.")
        else:
            report_lines.append("- ❌ Multiple validation failures.")
            report_lines.append("- Major issues with model training or configuration.")

        report_lines.append("")
        report_lines.append("## Files Generated")
        report_lines.append(f"- Summary report: {self.output_dir}/validation_summary.md")
        report_lines.append(f"- Detailed results: {self.output_dir}/validation_results.json")
        if (self.output_dir / "comparison").exists():
            report_lines.append(f"- Comparison results: {self.output_dir}/comparison/")
        if (self.output_dir / "comprehensive").exists():
            report_lines.append(f"- Comprehensive results: {self.output_dir}/comprehensive/")

        # Save report
        with open(self.output_dir / "validation_summary.md", 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))

        with open(self.output_dir / "validation_results.json", 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)

        # Print summary to console
        print("\n".join(report_lines))

        logger.info(f"📋 Summary report saved to: {self.output_dir}/validation_summary.md")
        logger.info(f"📋 Detailed results saved to: {self.output_dir}/validation_results.json")

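
# The orchestrator can also be driven programmatically; a minimal sketch (assumes the
# helper scripts under scripts/ are available and that the object passed in exposes
# the attributes the methods read, e.g. an argparse.Namespace):
#
#   from argparse import Namespace
#   orch = ValidationOrchestrator("./validation_results")
#   orch.run_quick_validation(Namespace(retriever_model="path/to/model", reranker_model=None))
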
def parse_args():
    parser = argparse.ArgumentParser(
        description="BGE Model Validation - Single Entry Point",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )
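    # Because the module docstring is passed as epilog and RawDescriptionHelpFormatter
    # keeps its formatting, `--help` reprints the usage examples from the top of this file.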

    # Validation mode (required)
    parser.add_argument(
        "mode",
        choices=['quick', 'comprehensive', 'compare', 'benchmark', 'all'],
        help="Validation mode to run"
    )

    # Model paths
    parser.add_argument("--retriever_model", type=str, help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, help="Path to fine-tuned reranker model")

    # Data paths
    parser.add_argument("--test_data_dir", type=str, help="Directory containing test datasets")
    parser.add_argument("--retriever_data", type=str, help="Path to retriever test data")
    parser.add_argument("--reranker_data", type=str, help="Path to reranker test data")

    # Baseline models for comparison
    parser.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3",
                        help="Baseline retriever model")
    parser.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base",
                        help="Baseline reranker model")

    # Configuration
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)")
    parser.add_argument("--output_dir", type=str, default="./validation_results",
                        help="Directory to save results")

    return parser.parse_args()

def main():
    args = parse_args()

    # Validate arguments
    if not args.retriever_model and not args.reranker_model:
        logger.error("❌ At least one model (--retriever_model or --reranker_model) is required")
        return 1

    logger.info("🚀 BGE Model Validation Started")
    logger.info(f"Mode: {args.mode}")
    logger.info(f"Output directory: {args.output_dir}")

    # Initialize orchestrator
    orchestrator = ValidationOrchestrator(args.output_dir)

    # Run validation based on mode
    try:
        if args.mode == 'quick':
            results = orchestrator.run_quick_validation(args)
        elif args.mode == 'comprehensive':
            results = orchestrator.run_comprehensive_validation(args)
        elif args.mode == 'compare':
            results = orchestrator.run_comparison(args)
        elif args.mode == 'benchmark':
            results = orchestrator.run_benchmark(args)
        elif args.mode == 'all':
            results = orchestrator.run_all_validation(args)
        else:
            logger.error(f"Unknown validation mode: {args.mode}")
            return 1

        logger.info("✅ Validation completed successfully!")
        logger.info(f"📁 Results saved to: {args.output_dir}")

        return 0

    except Exception as e:
        logger.error(f"❌ Validation failed: {e}")
        return 1

if __name__ == "__main__":
    sys.exit(main())