bge_finetune/scripts/validate.py
#!/usr/bin/env python3
"""
BGE Model Validation - Single Entry Point

This script provides a unified, streamlined interface for all BGE model validation
needs. It replaces the multiple confusing validation scripts with one clear,
easy-to-use tool.

Usage Examples:

    # Quick validation (5 minutes)
    python scripts/validate.py quick \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model

    # Comprehensive validation (30 minutes)
    python scripts/validate.py comprehensive \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --test_data_dir "data/datasets/三国演义/splits"

    # Compare against baselines
    python scripts/validate.py compare \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --retriever_data "data/datasets/三国演义/splits/m3_test.jsonl" \
        --reranker_data "data/datasets/三国演义/splits/reranker_test.jsonl"

    # Performance benchmarking only
    python scripts/validate.py benchmark \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model

    # Complete validation suite (includes all of the above)
    python scripts/validate.py all \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --test_data_dir "data/datasets/三国演义/splits"
"""
import os
import sys
import argparse
import logging
import json
import subprocess
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Any

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ValidationOrchestrator:
    """Orchestrates all validation activities through a single interface."""

    def __init__(self, output_dir: str = "./validation_results"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.results = {}
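
    # A sketch of the expected layout of self.output_dir after a full "all" run,
    # derived from the methods below. The per-phase files are written by the sibling
    # scripts (quick_validation.py, comprehensive_validation.py, compare_models.py,
    # benchmark.py), so treat their exact names and contents as assumptions:
    #
    #   validation_results/
    #       quick_validation.json       (run_quick_validation)
    #       comprehensive/              (run_comprehensive_validation)
    #       comparison/                 (run_comparison: retriever/reranker JSON)
    #       retriever_benchmark.txt     (run_benchmark)
    #       reranker_benchmark.txt      (run_benchmark)
    #       validation_summary.md       (_generate_summary_report)
    #       validation_results.json     (_generate_summary_report)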

    def run_quick_validation(self, args) -> Dict[str, Any]:
        """Run quick validation using the existing quick_validation.py script."""
        logger.info("🏃‍♂️ Running Quick Validation (5 minutes)")

        cmd = ["python", "scripts/quick_validation.py"]
        if args.retriever_model:
            cmd.extend(["--retriever_model", args.retriever_model])
        if args.reranker_model:
            cmd.extend(["--reranker_model", args.reranker_model])
        if args.retriever_model and args.reranker_model:
            cmd.append("--test_pipeline")

        # Save output to file
        output_file = self.output_dir / "quick_validation.json"
        cmd.extend(["--output_file", str(output_file)])

        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
            logger.info("✅ Quick validation completed successfully")
            # Load results if available
            if output_file.exists():
                with open(output_file, 'r') as f:
                    return json.load(f)
            return {"status": "completed", "output": result.stdout}
        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Quick validation failed: {e.stderr}")
            return {"status": "failed", "error": str(e)}

    def run_comprehensive_validation(self, args) -> Dict[str, Any]:
        """Run comprehensive validation."""
        logger.info("🔬 Running Comprehensive Validation (30 minutes)")

        cmd = ["python", "scripts/comprehensive_validation.py"]
        if args.retriever_model:
            cmd.extend(["--retriever_finetuned", args.retriever_model])
        if args.reranker_model:
            cmd.extend(["--reranker_finetuned", args.reranker_model])

        # Add test datasets
        if args.test_data_dir:
            test_dir = Path(args.test_data_dir)
            test_datasets = []
            if (test_dir / "m3_test.jsonl").exists():
                test_datasets.append(str(test_dir / "m3_test.jsonl"))
            if (test_dir / "reranker_test.jsonl").exists():
                test_datasets.append(str(test_dir / "reranker_test.jsonl"))
            if test_datasets:
                cmd.extend(["--test_datasets"] + test_datasets)

        # Configuration
        cmd.extend([
            "--output_dir", str(self.output_dir / "comprehensive"),
            "--batch_size", str(args.batch_size),
            "--k_values", "1", "3", "5", "10", "20"
        ])

        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
            logger.info("✅ Comprehensive validation completed successfully")
            # Load results if available
            results_file = self.output_dir / "comprehensive" / "validation_results.json"
            if results_file.exists():
                with open(results_file, 'r') as f:
                    return json.load(f)
            return {"status": "completed", "output": result.stdout}
        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Comprehensive validation failed: {e.stderr}")
            return {"status": "failed", "error": str(e)}

    def run_comparison(self, args) -> Dict[str, Any]:
        """Run model comparison against baselines."""
        logger.info("⚖️ Running Model Comparison")
        results = {}

        # Compare retriever if available
        if args.retriever_model and args.retriever_data:
            cmd = [
                "python", "scripts/compare_models.py",
                "--model_type", "retriever",
                "--finetuned_model", args.retriever_model,
                "--baseline_model", args.baseline_retriever or "BAAI/bge-m3",
                "--data_path", args.retriever_data,
                "--batch_size", str(args.batch_size),
                "--output_dir", str(self.output_dir / "comparison")
            ]
            if args.max_samples:
                cmd.extend(["--max_samples", str(args.max_samples)])
            try:
                subprocess.run(cmd, check=True, capture_output=True, text=True)
                logger.info("✅ Retriever comparison completed")
                # Load results
                results_file = self.output_dir / "comparison" / "retriever_comparison.json"
                if results_file.exists():
                    with open(results_file, 'r') as f:
                        results['retriever'] = json.load(f)
            except subprocess.CalledProcessError as e:
                logger.error(f"❌ Retriever comparison failed: {e.stderr}")
                results['retriever'] = {"status": "failed", "error": str(e)}

        # Compare reranker if available
        if args.reranker_model and args.reranker_data:
            cmd = [
                "python", "scripts/compare_models.py",
                "--model_type", "reranker",
                "--finetuned_model", args.reranker_model,
                "--baseline_model", args.baseline_reranker or "BAAI/bge-reranker-base",
                "--data_path", args.reranker_data,
                "--batch_size", str(args.batch_size),
                "--output_dir", str(self.output_dir / "comparison")
            ]
            if args.max_samples:
                cmd.extend(["--max_samples", str(args.max_samples)])
            try:
                subprocess.run(cmd, check=True, capture_output=True, text=True)
                logger.info("✅ Reranker comparison completed")
                # Load results
                results_file = self.output_dir / "comparison" / "reranker_comparison.json"
                if results_file.exists():
                    with open(results_file, 'r') as f:
                        results['reranker'] = json.load(f)
            except subprocess.CalledProcessError as e:
                logger.error(f"❌ Reranker comparison failed: {e.stderr}")
                results['reranker'] = {"status": "failed", "error": str(e)}

        return results
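
    # The comparison JSONs loaded above are produced by compare_models.py. A sketch
    # of the only fields this script itself relies on later (in
    # _generate_summary_report); the full schema is an assumption:
    #
    #   {"summary": {"status": "..."}, ...}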

    def run_benchmark(self, args) -> Dict[str, Any]:
        """Run performance benchmarking."""
        logger.info("⚡ Running Performance Benchmarking")
        results = {}

        # Benchmark retriever
        if args.retriever_model:
            retriever_data = args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl")
            if retriever_data:
                cmd = [
                    "python", "scripts/benchmark.py",
                    "--model_type", "retriever",
                    "--model_path", args.retriever_model,
                    "--data_path", retriever_data,
                    "--batch_size", str(args.batch_size),
                    "--output", str(self.output_dir / "retriever_benchmark.txt")
                ]
                if args.max_samples:
                    cmd.extend(["--max_samples", str(args.max_samples)])
                try:
                    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
                    logger.info("✅ Retriever benchmarking completed")
                    results['retriever'] = {"status": "completed", "output": result.stdout}
                except subprocess.CalledProcessError as e:
                    logger.error(f"❌ Retriever benchmarking failed: {e.stderr}")
                    results['retriever'] = {"status": "failed", "error": str(e)}

        # Benchmark reranker
        if args.reranker_model:
            reranker_data = args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl")
            if reranker_data:
                cmd = [
                    "python", "scripts/benchmark.py",
                    "--model_type", "reranker",
                    "--model_path", args.reranker_model,
                    "--data_path", reranker_data,
                    "--batch_size", str(args.batch_size),
                    "--output", str(self.output_dir / "reranker_benchmark.txt")
                ]
                if args.max_samples:
                    cmd.extend(["--max_samples", str(args.max_samples)])
                try:
                    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
                    logger.info("✅ Reranker benchmarking completed")
                    results['reranker'] = {"status": "completed", "output": result.stdout}
                except subprocess.CalledProcessError as e:
                    logger.error(f"❌ Reranker benchmarking failed: {e.stderr}")
                    results['reranker'] = {"status": "failed", "error": str(e)}

        return results

    def run_all_validation(self, args) -> Dict[str, Any]:
        """Run the complete validation suite."""
        logger.info("🎯 Running Complete Validation Suite")
        all_results = {
            'start_time': datetime.now().isoformat(),
            'validation_modes': []
        }

        # 1. Quick validation
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 1: QUICK VALIDATION")
        logger.info("=" * 60)
        all_results['quick'] = self.run_quick_validation(args)
        all_results['validation_modes'].append('quick')

        # 2. Comprehensive validation
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 2: COMPREHENSIVE VALIDATION")
        logger.info("=" * 60)
        all_results['comprehensive'] = self.run_comprehensive_validation(args)
        all_results['validation_modes'].append('comprehensive')

        # 3. Comparison validation
        if self._has_comparison_data(args):
            logger.info("\n" + "=" * 60)
            logger.info("PHASE 3: BASELINE COMPARISON")
            logger.info("=" * 60)
            all_results['comparison'] = self.run_comparison(args)
            all_results['validation_modes'].append('comparison')

        # 4. Performance benchmarking
        logger.info("\n" + "=" * 60)
        logger.info("PHASE 4: PERFORMANCE BENCHMARKING")
        logger.info("=" * 60)
        all_results['benchmark'] = self.run_benchmark(args)
        all_results['validation_modes'].append('benchmark')

        all_results['end_time'] = datetime.now().isoformat()

        # Generate summary report
        self._generate_summary_report(all_results)
        return all_results
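
    # A sketch of the aggregated dict returned by run_all_validation (and written to
    # validation_results.json by _generate_summary_report). The top-level keys are
    # derived from the code above; the nested per-phase payloads come from the
    # sibling scripts, so their exact fields are an assumption:
    #
    #   {
    #     "start_time": "...", "end_time": "...",
    #     "validation_modes": ["quick", "comprehensive", "comparison", "benchmark"],
    #     "quick": {...}, "comprehensive": {...},
    #     "comparison": {"retriever": {...}, "reranker": {...}},
    #     "benchmark": {"retriever": {"status": "completed", "output": "..."}, ...}
    #   }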

    def _find_test_data(self, test_data_dir: Optional[str], filename: str) -> Optional[str]:
        """Find a test data file in the specified directory, if it exists."""
        if not test_data_dir:
            return None
        test_path = Path(test_data_dir) / filename
        return str(test_path) if test_path.exists() else None

    def _has_comparison_data(self, args) -> bool:
        """Check whether the data needed for a baseline comparison is available."""
        has_retriever_data = bool(args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl"))
        has_reranker_data = bool(args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl"))
        return bool(
            (args.retriever_model and has_retriever_data)
            or (args.reranker_model and has_reranker_data)
        )

    def _generate_summary_report(self, results: Dict[str, Any]):
        """Generate a comprehensive summary report."""
        logger.info("\n" + "=" * 80)
        logger.info("📋 VALIDATION SUITE SUMMARY REPORT")
        logger.info("=" * 80)

        report_lines = []
        report_lines.append("# BGE Model Validation Report")
        report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report_lines.append("")

        # Summary
        completed_phases = len(results['validation_modes'])
        total_phases = 4
        success_rate = completed_phases / total_phases * 100
        report_lines.append("## Overall Summary")
        report_lines.append(f"- Validation phases completed: {completed_phases}/{total_phases} ({success_rate:.1f}%)")
        report_lines.append(f"- Start time: {results['start_time']}")
        report_lines.append(f"- End time: {results['end_time']}")
        report_lines.append("")

        # Phase results
        for phase in ['quick', 'comprehensive', 'comparison', 'benchmark']:
            if phase in results:
                phase_result = results[phase]
                status = "✅ SUCCESS" if phase_result.get('status') != 'failed' else "❌ FAILED"
                report_lines.append(f"### {phase.title()} Validation: {status}")
                if phase == 'comparison' and isinstance(phase_result, dict):
                    if 'retriever' in phase_result:
                        retriever_summary = phase_result['retriever'].get('summary', {})
                        if 'status' in retriever_summary:
                            report_lines.append(f" - Retriever: {retriever_summary['status']}")
                    if 'reranker' in phase_result:
                        reranker_summary = phase_result['reranker'].get('summary', {})
                        if 'status' in reranker_summary:
                            report_lines.append(f" - Reranker: {reranker_summary['status']}")
                report_lines.append("")

        # Recommendations
        report_lines.append("## Recommendations")
        if success_rate == 100:
            report_lines.append("- 🌟 All validation phases completed successfully!")
            report_lines.append("- Your models are ready for production deployment.")
        elif success_rate >= 75:
            report_lines.append("- ✅ Most validation phases successful.")
            report_lines.append("- Review failed phases and address any issues.")
        elif success_rate >= 50:
            report_lines.append("- ⚠️ Partial validation success.")
            report_lines.append("- Significant issues detected. Review logs carefully.")
        else:
            report_lines.append("- ❌ Multiple validation failures.")
            report_lines.append("- Major issues with model training or configuration.")
        report_lines.append("")

        report_lines.append("## Files Generated")
        report_lines.append(f"- Summary report: {self.output_dir}/validation_summary.md")
        report_lines.append(f"- Detailed results: {self.output_dir}/validation_results.json")
        if (self.output_dir / "comparison").exists():
            report_lines.append(f"- Comparison results: {self.output_dir}/comparison/")
        if (self.output_dir / "comprehensive").exists():
            report_lines.append(f"- Comprehensive results: {self.output_dir}/comprehensive/")

        # Save report
        with open(self.output_dir / "validation_summary.md", 'w', encoding='utf-8') as f:
            f.write('\n'.join(report_lines))
        with open(self.output_dir / "validation_results.json", 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)

        # Print summary to console
        print("\n".join(report_lines))
        logger.info(f"📋 Summary report saved to: {self.output_dir}/validation_summary.md")
        logger.info(f"📋 Detailed results saved to: {self.output_dir}/validation_results.json")


def parse_args():
    parser = argparse.ArgumentParser(
        description="BGE Model Validation - Single Entry Point",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )

    # Validation mode (required)
    parser.add_argument(
        "mode",
        choices=['quick', 'comprehensive', 'compare', 'benchmark', 'all'],
        help="Validation mode to run"
    )

    # Model paths
    parser.add_argument("--retriever_model", type=str, help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, help="Path to fine-tuned reranker model")

    # Data paths
    parser.add_argument("--test_data_dir", type=str, help="Directory containing test datasets")
    parser.add_argument("--retriever_data", type=str, help="Path to retriever test data")
    parser.add_argument("--reranker_data", type=str, help="Path to reranker test data")

    # Baseline models for comparison
    parser.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3",
                        help="Baseline retriever model")
    parser.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base",
                        help="Baseline reranker model")

    # Configuration
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)")
    parser.add_argument("--output_dir", type=str, default="./validation_results",
                        help="Directory to save results")

    return parser.parse_args()


def main():
    args = parse_args()

    # Validate arguments
    if not args.retriever_model and not args.reranker_model:
        logger.error("❌ At least one model (--retriever_model or --reranker_model) is required")
        return 1

    logger.info("🚀 BGE Model Validation Started")
    logger.info(f"Mode: {args.mode}")
    logger.info(f"Output directory: {args.output_dir}")

    # Initialize orchestrator
    orchestrator = ValidationOrchestrator(args.output_dir)

    # Run validation based on mode
    try:
        if args.mode == 'quick':
            results = orchestrator.run_quick_validation(args)
        elif args.mode == 'comprehensive':
            results = orchestrator.run_comprehensive_validation(args)
        elif args.mode == 'compare':
            results = orchestrator.run_comparison(args)
        elif args.mode == 'benchmark':
            results = orchestrator.run_benchmark(args)
        elif args.mode == 'all':
            results = orchestrator.run_all_validation(args)
        else:
            logger.error(f"Unknown validation mode: {args.mode}")
            return 1

        logger.info("✅ Validation completed successfully!")
        logger.info(f"📁 Results saved to: {args.output_dir}")
        return 0
    except Exception as e:
        logger.error(f"❌ Validation failed: {e}")
        return 1
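

# Exit status: main() returns 0 on success and 1 on missing arguments or any
# runtime failure, so the call below can gate CI jobs or shell pipelines, e.g.
#   python scripts/validate.py quick --retriever_model PATH && echo "validation OK"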
if __name__ == "__main__":
    sys.exit(main())