bge_finetune/scripts/run_validation_suite.py
#!/usr/bin/env python3
"""
BGE Validation Suite Runner
This script runs a complete validation suite for BGE fine-tuned models,
including multiple validation approaches and generating comprehensive reports.
Usage:
# Auto-discover models and run full suite
python scripts/run_validation_suite.py --auto-discover
# Run with specific models
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--reranker_model ./output/bge-reranker/final_model
# Run only specific validation types
python scripts/run_validation_suite.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--validation_types quick comprehensive comparison
"""
import os
import sys
import json
import argparse
import subprocess
import logging
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional, Any
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from scripts.validation_utils import ModelDiscovery, TestDataManager, ValidationReportAnalyzer
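
# The helpers imported above are used as follows in this script (interface
# inferred from the call sites below, not from validation_utils itself):
#   ModelDiscovery(workspace_root).find_trained_models()
#       -> {'retrievers': {name: {'path': ...}}, 'rerankers': {name: {'path': ...}}}
#   TestDataManager(data_dir).find_test_datasets()
#       -> {'retriever_datasets': [...], 'reranker_datasets': [...], 'general_datasets': [...]}
#   ValidationReportAnalyzer(results_dir).analyze_results()
#       -> analysis dict with keys such as 'overall_status' and 'recommendations'
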
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ValidationSuiteRunner:
    """Orchestrates comprehensive validation of BGE models"""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.results_dir = Path(config.get('output_dir', './validation_suite_results'))
        self.results_dir.mkdir(parents=True, exist_ok=True)

        # Initialize components
        self.model_discovery = ModelDiscovery(config.get('workspace_root', '.'))
        self.data_manager = TestDataManager(config.get('data_dir', 'data/datasets'))

        # Suite results
        self.suite_results = {
            'start_time': datetime.now().isoformat(),
            'config': config,
            'validation_results': {},
            'summary': {},
            'issues': [],
            'recommendations': []
        }

    def run_suite(self) -> Dict[str, Any]:
        """Run the complete validation suite"""
        logger.info("🚀 Starting BGE Validation Suite")
        logger.info("=" * 60)

        try:
            # Phase 1: Discovery and Setup
            self._phase_discovery()

            # Phase 2: Quick Validation
            if 'quick' in self.config.get('validation_types', ['quick']):
                self._phase_quick_validation()

            # Phase 3: Comprehensive Validation
            if 'comprehensive' in self.config.get('validation_types', ['comprehensive']):
                self._phase_comprehensive_validation()

            # Phase 4: Comparison Benchmarks
            if 'comparison' in self.config.get('validation_types', ['comparison']):
                self._phase_comparison_benchmarks()

            # Phase 5: Analysis and Reporting
            self._phase_analysis()

            # Phase 6: Generate Final Report
            self._generate_final_report()

            self.suite_results['end_time'] = datetime.now().isoformat()
            self.suite_results['status'] = 'completed'
            logger.info("✅ Validation Suite Completed Successfully!")
        except Exception as e:
            logger.error(f"❌ Validation Suite Failed: {e}")
            self.suite_results['status'] = 'failed'
            self.suite_results['error'] = str(e)
            raise

        return self.suite_results

    def _phase_discovery(self):
        """Phase 1: Discover models and datasets"""
        logger.info("📡 Phase 1: Discovery and Setup")

        # Auto-discover models if requested
        if self.config.get('auto_discover', False):
            discovered_models = self.model_discovery.find_trained_models()

            # Set models from discovery if not provided
            if not self.config.get('retriever_model'):
                if discovered_models['retrievers']:
                    model_name = list(discovered_models['retrievers'].keys())[0]
                    self.config['retriever_model'] = discovered_models['retrievers'][model_name]['path']
                    logger.info(f" 📊 Auto-discovered retriever: {model_name}")

            if not self.config.get('reranker_model'):
                if discovered_models['rerankers']:
                    model_name = list(discovered_models['rerankers'].keys())[0]
                    self.config['reranker_model'] = discovered_models['rerankers'][model_name]['path']
                    logger.info(f" 🏆 Auto-discovered reranker: {model_name}")

            self.suite_results['discovered_models'] = discovered_models

        # Discover test datasets
        discovered_datasets = self.data_manager.find_test_datasets()
        self.suite_results['available_datasets'] = discovered_datasets

        # Validate models exist
        issues = []
        if self.config.get('retriever_model'):
            if not os.path.exists(self.config['retriever_model']):
                issues.append(f"Retriever model not found: {self.config['retriever_model']}")
        if self.config.get('reranker_model'):
            if not os.path.exists(self.config['reranker_model']):
                issues.append(f"Reranker model not found: {self.config['reranker_model']}")

        if issues:
            for issue in issues:
                logger.error(f"{issue}")
            self.suite_results['issues'].extend(issues)
            return

        logger.info(" ✅ Discovery phase completed")

    def _phase_quick_validation(self):
        """Phase 2: Quick validation tests"""
        logger.info("⚡ Phase 2: Quick Validation")

        try:
            # Run quick validation script
            cmd = ['python', 'scripts/quick_validation.py']
            if self.config.get('retriever_model'):
                cmd.extend(['--retriever_model', self.config['retriever_model']])
            if self.config.get('reranker_model'):
                cmd.extend(['--reranker_model', self.config['reranker_model']])
            if self.config.get('retriever_model') and self.config.get('reranker_model'):
                cmd.append('--test_pipeline')

            # Save results to file
            quick_results_file = self.results_dir / 'quick_validation_results.json'
            cmd.extend(['--output_file', str(quick_results_file)])

            logger.info(f" Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                logger.info(" ✅ Quick validation completed")
                # Load results
                if quick_results_file.exists():
                    with open(quick_results_file, 'r') as f:
                        quick_results = json.load(f)
                    self.suite_results['validation_results']['quick'] = quick_results
            else:
                logger.error(f" ❌ Quick validation failed: {result.stderr}")
                self.suite_results['issues'].append(f"Quick validation failed: {result.stderr}")
        except Exception as e:
            logger.error(f" ❌ Error in quick validation: {e}")
            self.suite_results['issues'].append(f"Quick validation error: {e}")

    def _phase_comprehensive_validation(self):
        """Phase 3: Comprehensive validation"""
        logger.info("🔬 Phase 3: Comprehensive Validation")

        try:
            # Run comprehensive validation script
            cmd = ['python', 'scripts/comprehensive_validation.py']
            if self.config.get('retriever_model'):
                cmd.extend(['--retriever_finetuned', self.config['retriever_model']])
            if self.config.get('reranker_model'):
                cmd.extend(['--reranker_finetuned', self.config['reranker_model']])

            # Use specific datasets if provided
            if self.config.get('test_datasets'):
                cmd.extend(['--test_datasets'] + self.config['test_datasets'])

            # Set output directory
            comprehensive_dir = self.results_dir / 'comprehensive'
            cmd.extend(['--output_dir', str(comprehensive_dir)])

            logger.info(f" Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                logger.info(" ✅ Comprehensive validation completed")
                # Load results
                results_file = comprehensive_dir / 'validation_results.json'
                if results_file.exists():
                    with open(results_file, 'r') as f:
                        comp_results = json.load(f)
                    self.suite_results['validation_results']['comprehensive'] = comp_results
            else:
                logger.error(f" ❌ Comprehensive validation failed: {result.stderr}")
                self.suite_results['issues'].append(f"Comprehensive validation failed: {result.stderr}")
        except Exception as e:
            logger.error(f" ❌ Error in comprehensive validation: {e}")
            self.suite_results['issues'].append(f"Comprehensive validation error: {e}")

    def _phase_comparison_benchmarks(self):
        """Phase 4: Run comparison benchmarks"""
        logger.info("📊 Phase 4: Comparison Benchmarks")

        # Find suitable datasets for comparison
        datasets = self.data_manager.find_test_datasets()

        try:
            # Run retriever comparison if model exists
            if self.config.get('retriever_model'):
                self._run_retriever_comparison(datasets)

            # Run reranker comparison if model exists
            if self.config.get('reranker_model'):
                self._run_reranker_comparison(datasets)
        except Exception as e:
            logger.error(f" ❌ Error in comparison benchmarks: {e}")
            self.suite_results['issues'].append(f"Comparison benchmarks error: {e}")

    def _run_retriever_comparison(self, datasets: Dict[str, List[str]]):
        """Run retriever comparison benchmarks"""
        logger.info(" 🔍 Running retriever comparison...")

        # Find suitable datasets for retriever testing
        test_datasets = datasets.get('retriever_datasets', []) + datasets.get('general_datasets', [])
        if not test_datasets:
            logger.warning(" ⚠️ No suitable datasets found for retriever comparison")
            return

        for dataset_path in test_datasets[:2]:  # Limit to 2 datasets for speed
            if not os.path.exists(dataset_path):
                continue

            try:
                cmd = [
                    'python', 'scripts/compare_models.py',
                    '--model_type', 'retriever',
                    '--finetuned_model_path', self.config['retriever_model'],
                    '--baseline_model_path', self.config.get('retriever_baseline', 'BAAI/bge-m3'),
                    '--data_path', dataset_path,
                    '--batch_size', str(self.config.get('batch_size', 16)),
                    '--max_samples', str(self.config.get('max_samples', 500)),
                    '--output', str(self.results_dir / f'retriever_comparison_{Path(dataset_path).stem}.txt')
                ]

                result = subprocess.run(cmd, capture_output=True, text=True)
                if result.returncode == 0:
                    logger.info(f" ✅ Retriever comparison on {Path(dataset_path).name}")
                else:
                    logger.warning(f" ⚠️ Retriever comparison failed on {Path(dataset_path).name}")
            except Exception as e:
                logger.warning(f" ⚠️ Error comparing retriever on {dataset_path}: {e}")

    def _run_reranker_comparison(self, datasets: Dict[str, List[str]]):
        """Run reranker comparison benchmarks"""
        logger.info(" 🏆 Running reranker comparison...")

        # Find suitable datasets for reranker testing
        test_datasets = datasets.get('reranker_datasets', []) + datasets.get('general_datasets', [])
        if not test_datasets:
            logger.warning(" ⚠️ No suitable datasets found for reranker comparison")
            return

        for dataset_path in test_datasets[:2]:  # Limit to 2 datasets for speed
            if not os.path.exists(dataset_path):
                continue

            try:
                cmd = [
                    'python', 'scripts/compare_models.py',
                    '--model_type', 'reranker',
                    '--finetuned_model_path', self.config['reranker_model'],
                    '--baseline_model_path', self.config.get('reranker_baseline', 'BAAI/bge-reranker-base'),
                    '--data_path', dataset_path,
                    '--batch_size', str(self.config.get('batch_size', 16)),
                    '--max_samples', str(self.config.get('max_samples', 500)),
                    '--output', str(self.results_dir / f'reranker_comparison_{Path(dataset_path).stem}.txt')
                ]

                result = subprocess.run(cmd, capture_output=True, text=True)
                if result.returncode == 0:
                    logger.info(f" ✅ Reranker comparison on {Path(dataset_path).name}")
                else:
                    logger.warning(f" ⚠️ Reranker comparison failed on {Path(dataset_path).name}")
            except Exception as e:
                logger.warning(f" ⚠️ Error comparing reranker on {dataset_path}: {e}")

    def _phase_analysis(self):
        """Phase 5: Analyze all results"""
        logger.info("📈 Phase 5: Results Analysis")

        try:
            # Analyze comprehensive results if available
            comprehensive_dir = self.results_dir / 'comprehensive'
            if comprehensive_dir.exists():
                analyzer = ValidationReportAnalyzer(str(comprehensive_dir))
                analysis = analyzer.analyze_results()
                self.suite_results['analysis'] = analysis
                logger.info(" ✅ Results analysis completed")
            else:
                logger.warning(" ⚠️ No comprehensive results found for analysis")
        except Exception as e:
            logger.error(f" ❌ Error in results analysis: {e}")
            self.suite_results['issues'].append(f"Analysis error: {e}")

    def _generate_final_report(self):
        """Generate final comprehensive report"""
        logger.info("📝 Phase 6: Generating Final Report")

        try:
            # Create summary
            self._create_suite_summary()

            # Save suite results
            suite_results_file = self.results_dir / 'validation_suite_results.json'
            with open(suite_results_file, 'w', encoding='utf-8') as f:
                json.dump(self.suite_results, f, indent=2, ensure_ascii=False)

            # Generate HTML report
            self._generate_html_report()

            logger.info(f" 📊 Reports saved to: {self.results_dir}")
        except Exception as e:
            logger.error(f" ❌ Error generating final report: {e}")
            self.suite_results['issues'].append(f"Report generation error: {e}")

    def _create_suite_summary(self):
        """Create suite summary"""
        summary = {
            'total_validations': 0,
            'successful_validations': 0,
            'failed_validations': 0,
            'overall_verdict': 'unknown',
            'key_findings': [],
            'critical_issues': []
        }

        # Count validations
        for validation_type, results in self.suite_results.get('validation_results', {}).items():
            summary['total_validations'] += 1
            if results:
                summary['successful_validations'] += 1
            else:
                summary['failed_validations'] += 1

        # Determine overall verdict
        if summary['successful_validations'] > 0 and summary['failed_validations'] == 0:
            if self.suite_results.get('analysis', {}).get('overall_status') == 'excellent':
                summary['overall_verdict'] = 'excellent'
            elif self.suite_results.get('analysis', {}).get('overall_status') in ['good', 'mixed']:
                summary['overall_verdict'] = 'good'
            else:
                summary['overall_verdict'] = 'fair'
        elif summary['successful_validations'] > summary['failed_validations']:
            summary['overall_verdict'] = 'partial'
        else:
            summary['overall_verdict'] = 'poor'

        # Extract key findings from analysis
        if 'analysis' in self.suite_results:
            analysis = self.suite_results['analysis']
            if 'recommendations' in analysis:
                summary['key_findings'] = analysis['recommendations'][:5]  # Top 5 recommendations

        # Extract critical issues
        summary['critical_issues'] = self.suite_results.get('issues', [])

        self.suite_results['summary'] = summary

    def _generate_html_report(self):
        """Generate HTML summary report"""
        html_file = self.results_dir / 'validation_suite_report.html'
        summary = self.suite_results.get('summary', {})

        verdict_colors = {
            'excellent': '#28a745',
            'good': '#17a2b8',
            'fair': '#ffc107',
            'partial': '#fd7e14',
            'poor': '#dc3545',
            'unknown': '#6c757d'
        }
        verdict_color = verdict_colors.get(summary.get('overall_verdict', 'unknown'), '#6c757d')

        html_content = f"""
<!DOCTYPE html>
<html>
<head>
    <title>BGE Validation Suite Report</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; line-height: 1.6; }}
        .header {{ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                   color: white; padding: 30px; border-radius: 10px; text-align: center; }}
        .verdict {{ font-size: 24px; font-weight: bold; color: {verdict_color};
                    margin: 20px 0; text-align: center; padding: 15px;
                    border: 3px solid {verdict_color}; border-radius: 10px; }}
        .section {{ margin: 30px 0; }}
        .metrics {{ display: flex; justify-content: space-around; margin: 20px 0; }}
        .metric {{ text-align: center; padding: 15px; background: #f8f9fa;
                   border-radius: 8px; min-width: 120px; }}
        .metric-value {{ font-size: 2em; font-weight: bold; color: #495057; }}
        .metric-label {{ color: #6c757d; }}
        .findings {{ background: #e3f2fd; padding: 20px; border-radius: 8px;
                     border-left: 4px solid #2196f3; }}
        .issues {{ background: #ffebee; padding: 20px; border-radius: 8px;
                   border-left: 4px solid #f44336; }}
        .timeline {{ margin: 20px 0; }}
        .timeline-item {{ margin: 10px 0; padding: 10px; background: #f8f9fa;
                          border-left: 3px solid #007bff; border-radius: 0 5px 5px 0; }}
        ul {{ padding-left: 20px; }}
        li {{ margin: 8px 0; }}
    </style>
</head>
<body>
    <div class="header">
        <h1>🚀 BGE Validation Suite Report</h1>
        <p>Comprehensive Model Performance Analysis</p>
        <p>Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
    </div>

    <div class="verdict">
        Overall Verdict: {summary.get('overall_verdict', 'UNKNOWN').upper()}
    </div>

    <div class="metrics">
        <div class="metric">
            <div class="metric-value">{summary.get('total_validations', 0)}</div>
            <div class="metric-label">Total Validations</div>
        </div>
        <div class="metric">
            <div class="metric-value">{summary.get('successful_validations', 0)}</div>
            <div class="metric-label">Successful</div>
        </div>
        <div class="metric">
            <div class="metric-value">{summary.get('failed_validations', 0)}</div>
            <div class="metric-label">Failed</div>
        </div>
    </div>
"""

        # Add key findings
        if summary.get('key_findings'):
            html_content += """
    <div class="section">
        <h2>🔍 Key Findings</h2>
        <div class="findings">
            <ul>
"""
            for finding in summary['key_findings']:
                html_content += f" <li>{finding}</li>\n"
            html_content += """
            </ul>
        </div>
    </div>
"""

        # Add critical issues if any
        if summary.get('critical_issues'):
            html_content += """
    <div class="section">
        <h2>⚠️ Critical Issues</h2>
        <div class="issues">
            <ul>
"""
            for issue in summary['critical_issues']:
                html_content += f" <li>{issue}</li>\n"
            html_content += """
            </ul>
        </div>
    </div>
"""

        # Add validation timeline
        html_content += f"""
    <div class="section">
        <h2>⏱️ Validation Timeline</h2>
        <div class="timeline">
            <div class="timeline-item">
                <strong>Suite Started:</strong> {self.suite_results.get('start_time', 'Unknown')}
            </div>
"""
        for validation_type in self.suite_results.get('validation_results', {}).keys():
            html_content += f"""
            <div class="timeline-item">
                <strong>{validation_type.title()} Validation:</strong> Completed
            </div>
"""
        html_content += f"""
            <div class="timeline-item">
                <strong>Suite Completed:</strong> {self.suite_results.get('end_time', 'In Progress')}
            </div>
        </div>
    </div>

    <div class="section">
        <h2>📁 Detailed Results</h2>
        <p>For detailed validation results, check the following files in <code>{self.results_dir}</code>:</p>
        <ul>
            <li><code>validation_suite_results.json</code> - Complete suite results</li>
            <li><code>quick_validation_results.json</code> - Quick validation results</li>
            <li><code>comprehensive/</code> - Comprehensive validation reports</li>
            <li><code>*_comparison_*.txt</code> - Model comparison results</li>
        </ul>
    </div>
</body>
</html>
"""

        with open(html_file, 'w', encoding='utf-8') as f:
            f.write(html_content)

        logger.info(f" 📊 HTML report: {html_file}")


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Run comprehensive BGE validation suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Auto-discover models and run full suite
    python scripts/run_validation_suite.py --auto_discover

    # Run with specific models
    python scripts/run_validation_suite.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model \\
        --reranker_model ./output/bge-reranker/final_model

    # Run only specific validation types
    python scripts/run_validation_suite.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model \\
        --validation_types quick comprehensive
"""
    )

    # Model paths
    parser.add_argument("--retriever_model", type=str, default=None,
                        help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, default=None,
                        help="Path to fine-tuned reranker model")
    parser.add_argument("--auto_discover", action="store_true",
                        help="Auto-discover trained models in workspace")

    # Baseline models
    parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
                        help="Baseline retriever model")
    parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
                        help="Baseline reranker model")

    # Validation configuration
    parser.add_argument("--validation_types", type=str, nargs="+",
                        default=["quick", "comprehensive", "comparison"],
                        choices=["quick", "comprehensive", "comparison"],
                        help="Types of validation to run")
    parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
                        help="Specific test datasets to use")

    # Performance settings
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size for evaluation")
    parser.add_argument("--max_samples", type=int, default=1000,
                        help="Maximum samples per dataset for speed")

    # Directories
    parser.add_argument("--workspace_root", type=str, default=".",
                        help="Workspace root directory")
    parser.add_argument("--data_dir", type=str, default="data/datasets",
                        help="Test datasets directory")
    parser.add_argument("--output_dir", type=str, default="./validation_suite_results",
                        help="Output directory for all results")

    return parser.parse_args()


def main():
    """Main function"""
    args = parse_args()

    # Validate input
    if not args.auto_discover and not args.retriever_model and not args.reranker_model:
        print("❌ Error: Must specify models or use --auto_discover")
        print(" Use --retriever_model, --reranker_model, or --auto_discover")
        return 1

    # Create configuration
    config = {
        'retriever_model': args.retriever_model,
        'reranker_model': args.reranker_model,
        'auto_discover': args.auto_discover,
        'retriever_baseline': args.retriever_baseline,
        'reranker_baseline': args.reranker_baseline,
        'validation_types': args.validation_types,
        'test_datasets': args.test_datasets,
        'batch_size': args.batch_size,
        'max_samples': args.max_samples,
        'workspace_root': args.workspace_root,
        'data_dir': args.data_dir,
        'output_dir': args.output_dir
    }

    # Run validation suite
    try:
        runner = ValidationSuiteRunner(config)
        results = runner.run_suite()

        # Print final summary
        summary = results.get('summary', {})
        verdict = summary.get('overall_verdict', 'unknown')

        print("\n" + "=" * 80)
        print("🎯 VALIDATION SUITE SUMMARY")
        print("=" * 80)

        verdict_emojis = {
            'excellent': '🌟',
            'good': '',
            'fair': '👌',
            'partial': '⚠️',
            'poor': '',
            'unknown': ''
        }
        print(f"{verdict_emojis.get(verdict, '')} Overall Verdict: {verdict.upper()}")
        print(f"📊 Validations: {summary.get('successful_validations', 0)}/{summary.get('total_validations', 0)} successful")

        if summary.get('key_findings'):
            print(f"\n🔍 Key Findings:")
            for i, finding in enumerate(summary['key_findings'][:3], 1):
                print(f" {i}. {finding}")

        if summary.get('critical_issues'):
            print(f"\n⚠️ Critical Issues:")
            for issue in summary['critical_issues'][:3]:
                print(f"{issue}")

        print(f"\n📁 Detailed results: {config['output_dir']}")
        print("🌐 Open validation_suite_report.html in your browser for full report")

        return 0 if verdict in ['excellent', 'good'] else 1
    except KeyboardInterrupt:
        print("\n❌ Validation suite interrupted by user")
        return 1
    except Exception as e:
        print(f"\n❌ Validation suite failed: {e}")
        return 1
if __name__ == "__main__":
exit(main())
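
# Exit status: 0 only when the overall verdict is 'excellent' or 'good';
# 1 otherwise (including missing arguments, failures, and interruption),
# so CI jobs can gate on this script, e.g.:
#   python scripts/run_validation_suite.py --auto_discover && echo "validation passed"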