#!/usr/bin/env python3
"""
BGE Model Validation - Single Entry Point

This script provides a unified, streamlined interface for all BGE model validation
needs. It consolidates the previously separate validation scripts into one clear,
easy-to-use tool.

Usage Examples:

    # Quick validation (5 minutes)
    python scripts/validate.py quick \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model

    # Comprehensive validation (30 minutes)
    python scripts/validate.py comprehensive \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --test_data_dir "data/datasets/三国演义/splits"

    # Compare against baselines
    python scripts/validate.py compare \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --retriever_data "data/datasets/三国演义/splits/m3_test.jsonl" \
        --reranker_data "data/datasets/三国演义/splits/reranker_test.jsonl"

    # Performance benchmarking only
    python scripts/validate.py benchmark \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model

    # Complete validation suite (includes all of the above)
    python scripts/validate.py all \
        --retriever_model ./output/bge_m3_三国演义/final_model \
        --reranker_model ./output/bge_reranker_三国演义/final_model \
        --test_data_dir "data/datasets/三国演义/splits"
"""

import os
import sys
import argparse
import logging
import json
import subprocess
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Any

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ValidationOrchestrator:
    """Orchestrates all validation activities through a single interface"""

    def __init__(self, output_dir: str = "./validation_results"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.results = {}

    def run_quick_validation(self, args) -> Dict[str, Any]:
        """Run quick validation using the existing quick_validation.py script"""
        logger.info("🏃‍♂️ Running Quick Validation (5 minutes)")

        cmd = ["python", "scripts/quick_validation.py"]

        if args.retriever_model:
            cmd.extend(["--retriever_model", args.retriever_model])
        if args.reranker_model:
            cmd.extend(["--reranker_model", args.reranker_model])
        if args.retriever_model and args.reranker_model:
            cmd.append("--test_pipeline")

        # Save output to file
        output_file = self.output_dir / "quick_validation.json"
        cmd.extend(["--output_file", str(output_file)])

        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
            logger.info("✅ Quick validation completed successfully")

            # Load results if available
            if output_file.exists():
                with open(output_file, 'r') as f:
                    return json.load(f)
            return {"status": "completed", "output": result.stdout}

        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Quick validation failed: {e.stderr}")
            return {"status": "failed", "error": str(e)}

    def run_comprehensive_validation(self, args) -> Dict[str, Any]:
        """Run comprehensive validation"""
        logger.info("🔬 Running Comprehensive Validation (30 minutes)")

        cmd = ["python", "scripts/comprehensive_validation.py"]

        if args.retriever_model:
            cmd.extend(["--retriever_finetuned", args.retriever_model])
        if args.reranker_model:
            cmd.extend(["--reranker_finetuned", args.reranker_model])

        # Add test datasets
        if args.test_data_dir:
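            # Only pass split files that actually exist, so the downstream
            # script never receives a path to a missing dataset.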
            test_dir = Path(args.test_data_dir)
            test_datasets = []
            if (test_dir / "m3_test.jsonl").exists():
                test_datasets.append(str(test_dir / "m3_test.jsonl"))
            if (test_dir / "reranker_test.jsonl").exists():
                test_datasets.append(str(test_dir / "reranker_test.jsonl"))
            if test_datasets:
                cmd.extend(["--test_datasets"] + test_datasets)

        # Configuration
        cmd.extend([
            "--output_dir", str(self.output_dir / "comprehensive"),
            "--batch_size", str(args.batch_size),
            "--k_values", "1", "3", "5", "10", "20"
        ])

        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
            logger.info("✅ Comprehensive validation completed successfully")

            # Load results if available
            results_file = self.output_dir / "comprehensive" / "validation_results.json"
            if results_file.exists():
                with open(results_file, 'r') as f:
                    return json.load(f)
            return {"status": "completed", "output": result.stdout}

        except subprocess.CalledProcessError as e:
            logger.error(f"❌ Comprehensive validation failed: {e.stderr}")
            return {"status": "failed", "error": str(e)}

    def run_comparison(self, args) -> Dict[str, Any]:
        """Run model comparison against baselines"""
        logger.info("⚖️ Running Model Comparison")

        results = {}

        # Compare retriever if available
        if args.retriever_model and args.retriever_data:
            cmd = [
                "python", "scripts/compare_models.py",
                "--model_type", "retriever",
                "--finetuned_model", args.retriever_model,
                "--baseline_model", args.baseline_retriever or "BAAI/bge-m3",
                "--data_path", args.retriever_data,
                "--batch_size", str(args.batch_size),
                "--output_dir", str(self.output_dir / "comparison")
            ]
            if args.max_samples:
                cmd.extend(["--max_samples", str(args.max_samples)])

            try:
                subprocess.run(cmd, check=True, capture_output=True, text=True)
                logger.info("✅ Retriever comparison completed")

                # Load results
                results_file = self.output_dir / "comparison" / "retriever_comparison.json"
                if results_file.exists():
                    with open(results_file, 'r') as f:
                        results['retriever'] = json.load(f)

            except subprocess.CalledProcessError as e:
                logger.error(f"❌ Retriever comparison failed: {e.stderr}")
                results['retriever'] = {"status": "failed", "error": str(e)}

        # Compare reranker if available
        if args.reranker_model and args.reranker_data:
            cmd = [
                "python", "scripts/compare_models.py",
                "--model_type", "reranker",
                "--finetuned_model", args.reranker_model,
                "--baseline_model", args.baseline_reranker or "BAAI/bge-reranker-base",
                "--data_path", args.reranker_data,
                "--batch_size", str(args.batch_size),
                "--output_dir", str(self.output_dir / "comparison")
            ]
            if args.max_samples:
                cmd.extend(["--max_samples", str(args.max_samples)])

            try:
                subprocess.run(cmd, check=True, capture_output=True, text=True)
                logger.info("✅ Reranker comparison completed")

                # Load results
                results_file = self.output_dir / "comparison" / "reranker_comparison.json"
                if results_file.exists():
                    with open(results_file, 'r') as f:
                        results['reranker'] = json.load(f)

            except subprocess.CalledProcessError as e:
                logger.error(f"❌ Reranker comparison failed: {e.stderr}")
                results['reranker'] = {"status": "failed", "error": str(e)}

        return results

    def run_benchmark(self, args) -> Dict[str, Any]:
        """Run performance benchmarking"""
        logger.info("⚡ Running Performance Benchmarking")

        results = {}

        # Benchmark retriever
        if args.retriever_model:
            retriever_data = args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl")
            if retriever_data:
                cmd = [
                    "python", "scripts/benchmark.py",
                    "--model_type", "retriever",
                    "--model_path", args.retriever_model,
                    "--data_path", retriever_data,
                    "--batch_size", str(args.batch_size),
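                    # Plain-text benchmark report written into the validation output directory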
"--output", str(self.output_dir / "retriever_benchmark.txt") ] if args.max_samples: cmd.extend(["--max_samples", str(args.max_samples)]) try: result = subprocess.run(cmd, check=True, capture_output=True, text=True) logger.info("✅ Retriever benchmarking completed") results['retriever'] = {"status": "completed", "output": result.stdout} except subprocess.CalledProcessError as e: logger.error(f"❌ Retriever benchmarking failed: {e.stderr}") results['retriever'] = {"status": "failed", "error": str(e)} # Benchmark reranker if args.reranker_model: reranker_data = args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl") if reranker_data: cmd = [ "python", "scripts/benchmark.py", "--model_type", "reranker", "--model_path", args.reranker_model, "--data_path", reranker_data, "--batch_size", str(args.batch_size), "--output", str(self.output_dir / "reranker_benchmark.txt") ] if args.max_samples: cmd.extend(["--max_samples", str(args.max_samples)]) try: result = subprocess.run(cmd, check=True, capture_output=True, text=True) logger.info("✅ Reranker benchmarking completed") results['reranker'] = {"status": "completed", "output": result.stdout} except subprocess.CalledProcessError as e: logger.error(f"❌ Reranker benchmarking failed: {e.stderr}") results['reranker'] = {"status": "failed", "error": str(e)} return results def run_all_validation(self, args) -> Dict[str, Any]: """Run complete validation suite""" logger.info("🎯 Running Complete Validation Suite") all_results = { 'start_time': datetime.now().isoformat(), 'validation_modes': [] } # 1. Quick validation logger.info("\n" + "="*60) logger.info("PHASE 1: QUICK VALIDATION") logger.info("="*60) all_results['quick'] = self.run_quick_validation(args) all_results['validation_modes'].append('quick') # 2. Comprehensive validation logger.info("\n" + "="*60) logger.info("PHASE 2: COMPREHENSIVE VALIDATION") logger.info("="*60) all_results['comprehensive'] = self.run_comprehensive_validation(args) all_results['validation_modes'].append('comprehensive') # 3. Comparison validation if self._has_comparison_data(args): logger.info("\n" + "="*60) logger.info("PHASE 3: BASELINE COMPARISON") logger.info("="*60) all_results['comparison'] = self.run_comparison(args) all_results['validation_modes'].append('comparison') # 4. 
        logger.info("\n" + "="*60)
        logger.info("PHASE 4: PERFORMANCE BENCHMARKING")
        logger.info("="*60)
        all_results['benchmark'] = self.run_benchmark(args)
        all_results['validation_modes'].append('benchmark')

        all_results['end_time'] = datetime.now().isoformat()

        # Generate summary report
        self._generate_summary_report(all_results)

        return all_results

    def _find_test_data(self, test_data_dir: Optional[str], filename: str) -> Optional[str]:
        """Find a test data file in the specified directory"""
        if not test_data_dir:
            return None
        test_path = Path(test_data_dir) / filename
        return str(test_path) if test_path.exists() else None

    def _has_comparison_data(self, args) -> bool:
        """Check whether we have the data needed for comparison"""
        has_retriever_data = bool(args.retriever_data or self._find_test_data(args.test_data_dir, "m3_test.jsonl"))
        has_reranker_data = bool(args.reranker_data or self._find_test_data(args.test_data_dir, "reranker_test.jsonl"))
        return bool((args.retriever_model and has_retriever_data) or (args.reranker_model and has_reranker_data))

    def _generate_summary_report(self, results: Dict[str, Any]):
        """Generate a comprehensive summary report"""
        logger.info("\n" + "="*80)
        logger.info("📋 VALIDATION SUITE SUMMARY REPORT")
        logger.info("="*80)

        report_lines = []
        report_lines.append("# BGE Model Validation Report")
        report_lines.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report_lines.append("")

        # Summary
        completed_phases = len(results['validation_modes'])
        total_phases = 4
        success_rate = completed_phases / total_phases * 100

        report_lines.append("## Overall Summary")
        report_lines.append(f"- Validation phases completed: {completed_phases}/{total_phases} ({success_rate:.1f}%)")
        report_lines.append(f"- Start time: {results['start_time']}")
        report_lines.append(f"- End time: {results['end_time']}")
        report_lines.append("")

        # Phase results
        for phase in ['quick', 'comprehensive', 'comparison', 'benchmark']:
            if phase in results:
                phase_result = results[phase]
                status = "✅ SUCCESS" if phase_result.get('status') != 'failed' else "❌ FAILED"
                report_lines.append(f"### {phase.title()} Validation: {status}")

                if phase == 'comparison' and isinstance(phase_result, dict):
                    if 'retriever' in phase_result:
                        retriever_summary = phase_result['retriever'].get('summary', {})
                        if 'status' in retriever_summary:
                            report_lines.append(f"  - Retriever: {retriever_summary['status']}")
                    if 'reranker' in phase_result:
                        reranker_summary = phase_result['reranker'].get('summary', {})
                        if 'status' in reranker_summary:
                            report_lines.append(f"  - Reranker: {reranker_summary['status']}")

                report_lines.append("")

        # Recommendations
        report_lines.append("## Recommendations")
        if success_rate == 100:
            report_lines.append("- 🌟 All validation phases completed successfully!")
            report_lines.append("- Your models are ready for production deployment.")
        elif success_rate >= 75:
            report_lines.append("- ✅ Most validation phases succeeded.")
            report_lines.append("- Review failed phases and address any issues.")
        elif success_rate >= 50:
            report_lines.append("- ⚠️ Partial validation success.")
            report_lines.append("- Significant issues detected. Review logs carefully.")
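        # Fewer than half of the phases completed: treat the run as a failure.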
Review logs carefully.") else: report_lines.append("- ❌ Multiple validation failures.") report_lines.append("- Major issues with model training or configuration.") report_lines.append("") report_lines.append("## Files Generated") report_lines.append(f"- Summary report: {self.output_dir}/validation_summary.md") report_lines.append(f"- Detailed results: {self.output_dir}/validation_results.json") if (self.output_dir / "comparison").exists(): report_lines.append(f"- Comparison results: {self.output_dir}/comparison/") if (self.output_dir / "comprehensive").exists(): report_lines.append(f"- Comprehensive results: {self.output_dir}/comprehensive/") # Save report with open(self.output_dir / "validation_summary.md", 'w', encoding='utf-8') as f: f.write('\n'.join(report_lines)) with open(self.output_dir / "validation_results.json", 'w', encoding='utf-8') as f: json.dump(results, f, indent=2, default=str) # Print summary to console print("\n".join(report_lines)) logger.info(f"📋 Summary report saved to: {self.output_dir}/validation_summary.md") logger.info(f"📋 Detailed results saved to: {self.output_dir}/validation_results.json") def parse_args(): parser = argparse.ArgumentParser( description="BGE Model Validation - Single Entry Point", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__ ) # Validation mode (required) parser.add_argument( "mode", choices=['quick', 'comprehensive', 'compare', 'benchmark', 'all'], help="Validation mode to run" ) # Model paths parser.add_argument("--retriever_model", type=str, help="Path to fine-tuned retriever model") parser.add_argument("--reranker_model", type=str, help="Path to fine-tuned reranker model") # Data paths parser.add_argument("--test_data_dir", type=str, help="Directory containing test datasets") parser.add_argument("--retriever_data", type=str, help="Path to retriever test data") parser.add_argument("--reranker_data", type=str, help="Path to reranker test data") # Baseline models for comparison parser.add_argument("--baseline_retriever", type=str, default="BAAI/bge-m3", help="Baseline retriever model") parser.add_argument("--baseline_reranker", type=str, default="BAAI/bge-reranker-base", help="Baseline reranker model") # Configuration parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation") parser.add_argument("--max_samples", type=int, help="Maximum samples to test (for speed)") parser.add_argument("--output_dir", type=str, default="./validation_results", help="Directory to save results") return parser.parse_args() def main(): args = parse_args() # Validate arguments if not args.retriever_model and not args.reranker_model: logger.error("❌ At least one model (--retriever_model or --reranker_model) is required") return 1 logger.info("🚀 BGE Model Validation Started") logger.info(f"Mode: {args.mode}") logger.info(f"Output directory: {args.output_dir}") # Initialize orchestrator orchestrator = ValidationOrchestrator(args.output_dir) # Run validation based on mode try: if args.mode == 'quick': results = orchestrator.run_quick_validation(args) elif args.mode == 'comprehensive': results = orchestrator.run_comprehensive_validation(args) elif args.mode == 'compare': results = orchestrator.run_comparison(args) elif args.mode == 'benchmark': results = orchestrator.run_benchmark(args) elif args.mode == 'all': results = orchestrator.run_all_validation(args) else: logger.error(f"Unknown validation mode: {args.mode}") return 1 logger.info("✅ Validation completed successfully!") logger.info(f"📁 Results saved to: 
{args.output_dir}") return 0 except Exception as e: logger.error(f"❌ Validation failed: {e}") return 1 if __name__ == "__main__": exit(main())