init commit
This commit is contained in:
641
scripts/validation_utils.py
Normal file
641
scripts/validation_utils.py
Normal file
@@ -0,0 +1,641 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Validation Utilities for BGE Fine-tuning
|
||||
|
||||
This module provides utilities for model validation including:
|
||||
- Automatic model discovery
|
||||
- Test data preparation and validation
|
||||
- Results analysis and reporting
|
||||
- Benchmark dataset creation
|
||||
|
||||
Usage:
|
||||
python scripts/validation_utils.py --discover-models
|
||||
python scripts/validation_utils.py --prepare-test-data
|
||||
python scripts/validation_utils.py --analyze-results ./validation_results
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import glob
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelDiscovery:
|
||||
"""Discover trained models in the workspace"""
|
||||
|
||||
def __init__(self, workspace_root: str = "."):
|
||||
self.workspace_root = Path(workspace_root)
|
||||
|
||||
def find_trained_models(self) -> Dict[str, Dict[str, str]]:
|
||||
"""Find all trained BGE models in the workspace"""
|
||||
models = {
|
||||
'retrievers': {},
|
||||
'rerankers': {},
|
||||
'joint': {}
|
||||
}
|
||||
|
||||
# Common output directories to search
|
||||
search_patterns = [
|
||||
"output/*/final_model",
|
||||
"output/*/best_model",
|
||||
"output/*/*/final_model",
|
||||
"output/*/*/best_model",
|
||||
"test_*_output/final_model",
|
||||
"training_output/*/final_model"
|
||||
]
|
||||
|
||||
for pattern in search_patterns:
|
||||
for model_path in glob.glob(str(self.workspace_root / pattern)):
|
||||
model_path = Path(model_path)
|
||||
|
||||
if self._is_valid_model_directory(model_path):
|
||||
model_info = self._analyze_model_directory(model_path)
|
||||
|
||||
if model_info['type'] == 'retriever':
|
||||
models['retrievers'][model_info['name']] = model_info
|
||||
elif model_info['type'] == 'reranker':
|
||||
models['rerankers'][model_info['name']] = model_info
|
||||
elif model_info['type'] == 'joint':
|
||||
models['joint'][model_info['name']] = model_info
|
||||
|
||||
return models
|
||||
|
||||
def _is_valid_model_directory(self, path: Path) -> bool:
|
||||
"""Check if directory contains a valid trained model"""
|
||||
required_files = ['config.json', 'pytorch_model.bin']
|
||||
alternative_files = ['model.safetensors', 'tokenizer.json']
|
||||
|
||||
has_required = any((path / f).exists() for f in required_files)
|
||||
has_alternative = any((path / f).exists() for f in alternative_files)
|
||||
|
||||
return has_required or has_alternative
|
||||
|
||||
def _analyze_model_directory(self, path: Path) -> Dict[str, str]:
|
||||
"""Analyze model directory to determine type and metadata"""
|
||||
info = {
|
||||
'path': str(path),
|
||||
'name': self._generate_model_name(path),
|
||||
'type': 'unknown',
|
||||
'created': 'unknown',
|
||||
'training_info': {}
|
||||
}
|
||||
|
||||
# Try to determine model type from path
|
||||
path_str = str(path).lower()
|
||||
if 'm3' in path_str or 'retriever' in path_str:
|
||||
info['type'] = 'retriever'
|
||||
elif 'reranker' in path_str:
|
||||
info['type'] = 'reranker'
|
||||
elif 'joint' in path_str:
|
||||
info['type'] = 'joint'
|
||||
|
||||
# Get creation time
|
||||
if path.exists():
|
||||
info['created'] = datetime.fromtimestamp(path.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# Try to read training info
|
||||
training_summary_path = path.parent / 'training_summary.json'
|
||||
if training_summary_path.exists():
|
||||
try:
|
||||
with open(training_summary_path, 'r') as f:
|
||||
info['training_info'] = json.load(f)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Try to read config
|
||||
config_path = path / 'config.json'
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
if 'model_type' in config:
|
||||
info['type'] = config['model_type']
|
||||
except:
|
||||
pass
|
||||
|
||||
return info
|
||||
|
||||
def _generate_model_name(self, path: Path) -> str:
|
||||
"""Generate a readable name for the model"""
|
||||
parts = path.parts
|
||||
|
||||
# Try to find meaningful parts
|
||||
meaningful_parts = []
|
||||
for part in parts[-3:]: # Look at last 3 parts
|
||||
if part not in ['final_model', 'best_model', 'output']:
|
||||
meaningful_parts.append(part)
|
||||
|
||||
if meaningful_parts:
|
||||
return '_'.join(meaningful_parts)
|
||||
else:
|
||||
return f"model_{path.parent.name}"
|
||||
|
||||
def print_discovered_models(self):
|
||||
"""Print discovered models in a formatted way"""
|
||||
models = self.find_trained_models()
|
||||
|
||||
print("\n🔍 DISCOVERED TRAINED MODELS")
|
||||
print("=" * 60)
|
||||
|
||||
for model_type, model_dict in models.items():
|
||||
if model_dict:
|
||||
print(f"\n📊 {model_type.upper()}:")
|
||||
for name, info in model_dict.items():
|
||||
print(f" • {name}")
|
||||
print(f" Path: {info['path']}")
|
||||
print(f" Created: {info['created']}")
|
||||
if info['training_info']:
|
||||
training_info = info['training_info']
|
||||
if 'total_steps' in training_info:
|
||||
print(f" Training Steps: {training_info['total_steps']}")
|
||||
if 'final_metrics' in training_info and training_info['final_metrics']:
|
||||
print(f" Final Metrics: {training_info['final_metrics']}")
|
||||
|
||||
if not any(models.values()):
|
||||
print("❌ No trained models found!")
|
||||
print("💡 Make sure you have trained some models first.")
|
||||
|
||||
return models
|
||||
|
||||
|
||||
class TestDataManager:
|
||||
"""Manage test datasets for validation"""
|
||||
|
||||
def __init__(self, data_dir: str = "data/datasets"):
|
||||
self.data_dir = Path(data_dir)
|
||||
|
||||
def find_test_datasets(self) -> Dict[str, List[str]]:
|
||||
"""Find available test datasets"""
|
||||
datasets = {
|
||||
'retriever_datasets': [],
|
||||
'reranker_datasets': [],
|
||||
'general_datasets': []
|
||||
}
|
||||
|
||||
# Search for JSONL files in data directory
|
||||
for jsonl_file in self.data_dir.rglob("*.jsonl"):
|
||||
file_path = str(jsonl_file)
|
||||
|
||||
# Categorize based on filename/path
|
||||
if 'reranker' in file_path.lower():
|
||||
datasets['reranker_datasets'].append(file_path)
|
||||
elif 'embedding' in file_path.lower() or 'm3' in file_path.lower():
|
||||
datasets['retriever_datasets'].append(file_path)
|
||||
else:
|
||||
datasets['general_datasets'].append(file_path)
|
||||
|
||||
# Also search for JSON files
|
||||
for json_file in self.data_dir.rglob("*.json"):
|
||||
file_path = str(json_file)
|
||||
if 'test' in file_path.lower() or 'dev' in file_path.lower():
|
||||
if 'reranker' in file_path.lower() or 'AFQMC' in file_path:
|
||||
datasets['reranker_datasets'].append(file_path)
|
||||
else:
|
||||
datasets['general_datasets'].append(file_path)
|
||||
|
||||
return datasets
|
||||
|
||||
def validate_dataset(self, dataset_path: str) -> Dict[str, Any]:
|
||||
"""Validate a dataset file and analyze its format"""
|
||||
validation_result = {
|
||||
'path': dataset_path,
|
||||
'exists': False,
|
||||
'format': 'unknown',
|
||||
'sample_count': 0,
|
||||
'has_queries': False,
|
||||
'has_passages': False,
|
||||
'has_labels': False,
|
||||
'issues': [],
|
||||
'sample_data': {}
|
||||
}
|
||||
|
||||
if not os.path.exists(dataset_path):
|
||||
validation_result['issues'].append(f"File does not exist: {dataset_path}")
|
||||
return validation_result
|
||||
|
||||
validation_result['exists'] = True
|
||||
|
||||
try:
|
||||
# Read first few lines to analyze format
|
||||
samples = []
|
||||
with open(dataset_path, 'r', encoding='utf-8') as f:
|
||||
for i, line in enumerate(f):
|
||||
if i >= 5: # Only analyze first 5 lines
|
||||
break
|
||||
try:
|
||||
sample = json.loads(line.strip())
|
||||
samples.append(sample)
|
||||
except json.JSONDecodeError:
|
||||
validation_result['issues'].append(f"Invalid JSON at line {i+1}")
|
||||
|
||||
if samples:
|
||||
validation_result['sample_count'] = len(samples)
|
||||
validation_result['sample_data'] = samples[0] if samples else {}
|
||||
|
||||
# Analyze format
|
||||
first_sample = samples[0]
|
||||
|
||||
# Check for common fields
|
||||
if 'query' in first_sample:
|
||||
validation_result['has_queries'] = True
|
||||
|
||||
if any(field in first_sample for field in ['passage', 'pos', 'neg', 'candidates']):
|
||||
validation_result['has_passages'] = True
|
||||
|
||||
if any(field in first_sample for field in ['label', 'labels', 'score']):
|
||||
validation_result['has_labels'] = True
|
||||
|
||||
# Determine format
|
||||
if 'pos' in first_sample and 'neg' in first_sample:
|
||||
validation_result['format'] = 'triplets'
|
||||
elif 'candidates' in first_sample:
|
||||
validation_result['format'] = 'candidates'
|
||||
elif 'passage' in first_sample and 'label' in first_sample:
|
||||
validation_result['format'] = 'pairs'
|
||||
elif 'sentence1' in first_sample and 'sentence2' in first_sample:
|
||||
validation_result['format'] = 'sentence_pairs'
|
||||
else:
|
||||
validation_result['format'] = 'custom'
|
||||
|
||||
except Exception as e:
|
||||
validation_result['issues'].append(f"Error reading file: {e}")
|
||||
|
||||
return validation_result
|
||||
|
||||
def print_dataset_analysis(self):
|
||||
"""Print analysis of available datasets"""
|
||||
datasets = self.find_test_datasets()
|
||||
|
||||
print("\n📁 AVAILABLE TEST DATASETS")
|
||||
print("=" * 60)
|
||||
|
||||
for category, dataset_list in datasets.items():
|
||||
if dataset_list:
|
||||
print(f"\n📊 {category.upper().replace('_', ' ')}:")
|
||||
for dataset_path in dataset_list:
|
||||
validation = self.validate_dataset(dataset_path)
|
||||
status = "✅" if validation['exists'] and not validation['issues'] else "⚠️"
|
||||
print(f" {status} {Path(dataset_path).name}")
|
||||
print(f" Path: {dataset_path}")
|
||||
print(f" Format: {validation['format']}")
|
||||
if validation['issues']:
|
||||
for issue in validation['issues']:
|
||||
print(f" ⚠️ {issue}")
|
||||
|
||||
if not any(datasets.values()):
|
||||
print("❌ No test datasets found!")
|
||||
print("💡 Place your test data in the data/datasets/ directory")
|
||||
|
||||
return datasets
|
||||
|
||||
def create_sample_test_data(self, output_dir: str = "data/test_samples"):
|
||||
"""Create sample test datasets for validation"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Create sample retriever test data
|
||||
retriever_test = [
|
||||
{
|
||||
"query": "什么是人工智能?",
|
||||
"pos": [
|
||||
"人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
|
||||
"AI技术包括机器学习、深度学习、自然语言处理等多个领域"
|
||||
],
|
||||
"neg": [
|
||||
"今天的天气非常好,阳光明媚",
|
||||
"我喜欢听音乐和看电影"
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "如何学习编程?",
|
||||
"pos": [
|
||||
"学习编程需要选择合适的语言,如Python、Java等,然后通过实践项目提升技能",
|
||||
"编程学习的关键是多写代码、多调试、多思考算法逻辑"
|
||||
],
|
||||
"neg": [
|
||||
"烹饪是一门艺术,需要掌握火候和调味",
|
||||
"运动对健康很重要,每天应该锻炼30分钟"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Create sample reranker test data
|
||||
reranker_test = [
|
||||
{
|
||||
"query": "什么是人工智能?",
|
||||
"passage": "人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
|
||||
"label": 1
|
||||
},
|
||||
{
|
||||
"query": "什么是人工智能?",
|
||||
"passage": "今天的天气非常好,阳光明媚",
|
||||
"label": 0
|
||||
},
|
||||
{
|
||||
"query": "如何学习编程?",
|
||||
"passage": "学习编程需要选择合适的语言,如Python、Java等,然后通过实践项目提升技能",
|
||||
"label": 1
|
||||
},
|
||||
{
|
||||
"query": "如何学习编程?",
|
||||
"passage": "烹饪是一门艺术,需要掌握火候和调味",
|
||||
"label": 0
|
||||
}
|
||||
]
|
||||
|
||||
# Save retriever test data
|
||||
retriever_path = os.path.join(output_dir, "retriever_test.jsonl")
|
||||
with open(retriever_path, 'w', encoding='utf-8') as f:
|
||||
for item in retriever_test:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
# Save reranker test data
|
||||
reranker_path = os.path.join(output_dir, "reranker_test.jsonl")
|
||||
with open(reranker_path, 'w', encoding='utf-8') as f:
|
||||
for item in reranker_test:
|
||||
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f"✅ Sample test datasets created in: {output_dir}")
|
||||
print(f" 📁 Retriever test: {retriever_path}")
|
||||
print(f" 📁 Reranker test: {reranker_path}")
|
||||
|
||||
return {
|
||||
'retriever_test': retriever_path,
|
||||
'reranker_test': reranker_path
|
||||
}
|
||||
|
||||
|
||||
class ValidationReportAnalyzer:
|
||||
"""Analyze validation results and generate insights"""
|
||||
|
||||
def __init__(self, results_dir: str):
|
||||
self.results_dir = Path(results_dir)
|
||||
|
||||
def analyze_results(self) -> Dict[str, Any]:
|
||||
"""Analyze validation results"""
|
||||
analysis = {
|
||||
'overall_status': 'unknown',
|
||||
'model_performance': {},
|
||||
'recommendations': [],
|
||||
'summary_stats': {}
|
||||
}
|
||||
|
||||
# Look for results files
|
||||
json_results = self.results_dir / "validation_results.json"
|
||||
|
||||
if not json_results.exists():
|
||||
analysis['recommendations'].append("No validation results found. Run validation first.")
|
||||
return analysis
|
||||
|
||||
try:
|
||||
with open(json_results, 'r') as f:
|
||||
results = json.load(f)
|
||||
|
||||
# Analyze each model type
|
||||
improvements = 0
|
||||
degradations = 0
|
||||
|
||||
for model_type in ['retriever', 'reranker', 'pipeline']:
|
||||
if model_type in results and 'comparisons' in results[model_type]:
|
||||
model_analysis = self._analyze_model_results(results[model_type])
|
||||
analysis['model_performance'][model_type] = model_analysis
|
||||
|
||||
improvements += model_analysis['total_improvements']
|
||||
degradations += model_analysis['total_degradations']
|
||||
|
||||
# Overall status
|
||||
if improvements > degradations * 1.5:
|
||||
analysis['overall_status'] = 'excellent'
|
||||
elif improvements > degradations:
|
||||
analysis['overall_status'] = 'good'
|
||||
elif improvements == degradations:
|
||||
analysis['overall_status'] = 'mixed'
|
||||
else:
|
||||
analysis['overall_status'] = 'poor'
|
||||
|
||||
# Generate recommendations
|
||||
analysis['recommendations'] = self._generate_recommendations(analysis)
|
||||
|
||||
# Summary statistics
|
||||
analysis['summary_stats'] = {
|
||||
'total_improvements': improvements,
|
||||
'total_degradations': degradations,
|
||||
'net_improvement': improvements - degradations,
|
||||
'improvement_ratio': improvements / (improvements + degradations) if (improvements + degradations) > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
analysis['recommendations'].append(f"Error analyzing results: {e}")
|
||||
|
||||
return analysis
|
||||
|
||||
def _analyze_model_results(self, model_results: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Analyze results for a specific model type"""
|
||||
analysis = {
|
||||
'total_improvements': 0,
|
||||
'total_degradations': 0,
|
||||
'best_dataset': None,
|
||||
'worst_dataset': None,
|
||||
'avg_improvement': 0,
|
||||
'issues': []
|
||||
}
|
||||
|
||||
dataset_scores = {}
|
||||
|
||||
for dataset_name, comparison in model_results.get('comparisons', {}).items():
|
||||
improvements = len(comparison.get('improvements', {}))
|
||||
degradations = len(comparison.get('degradations', {}))
|
||||
|
||||
analysis['total_improvements'] += improvements
|
||||
analysis['total_degradations'] += degradations
|
||||
|
||||
# Score dataset (improvements - degradations)
|
||||
dataset_score = improvements - degradations
|
||||
dataset_scores[dataset_name] = dataset_score
|
||||
|
||||
if dataset_scores:
|
||||
analysis['best_dataset'] = max(dataset_scores.items(), key=lambda x: x[1])
|
||||
analysis['worst_dataset'] = min(dataset_scores.items(), key=lambda x: x[1])
|
||||
analysis['avg_improvement'] = sum(dataset_scores.values()) / len(dataset_scores)
|
||||
|
||||
# Identify issues
|
||||
if analysis['total_degradations'] > analysis['total_improvements']:
|
||||
analysis['issues'].append("More degradations than improvements")
|
||||
|
||||
if analysis['avg_improvement'] < 0:
|
||||
analysis['issues'].append("Negative average improvement")
|
||||
|
||||
return analysis
|
||||
|
||||
def _generate_recommendations(self, analysis: Dict[str, Any]) -> List[str]:
|
||||
"""Generate recommendations based on analysis"""
|
||||
recommendations = []
|
||||
|
||||
overall_status = analysis['overall_status']
|
||||
|
||||
if overall_status == 'excellent':
|
||||
recommendations.append("🎉 Great job! Your fine-tuned models show significant improvements.")
|
||||
recommendations.append("Consider deploying these models to production.")
|
||||
|
||||
elif overall_status == 'good':
|
||||
recommendations.append("✅ Good progress! Models show overall improvement.")
|
||||
recommendations.append("Consider further fine-tuning on specific datasets where degradation occurred.")
|
||||
|
||||
elif overall_status == 'mixed':
|
||||
recommendations.append("⚠️ Mixed results. Some improvements, some degradations.")
|
||||
recommendations.append("Analyze which datasets/metrics are degrading and adjust training data.")
|
||||
recommendations.append("Consider different hyperparameters or longer training.")
|
||||
|
||||
elif overall_status == 'poor':
|
||||
recommendations.append("❌ Models show more degradations than improvements.")
|
||||
recommendations.append("Review training data quality and format.")
|
||||
recommendations.append("Check hyperparameters (learning rate might be too high).")
|
||||
recommendations.append("Consider starting with different base models.")
|
||||
|
||||
# Specific model recommendations
|
||||
for model_type, model_perf in analysis.get('model_performance', {}).items():
|
||||
if model_perf['issues']:
|
||||
recommendations.append(f"🔧 {model_type.title()} issues: {', '.join(model_perf['issues'])}")
|
||||
|
||||
if model_perf['best_dataset']:
|
||||
best_dataset, best_score = model_perf['best_dataset']
|
||||
recommendations.append(f"📊 {model_type.title()} performs best on: {best_dataset} (score: {best_score})")
|
||||
|
||||
if model_perf['worst_dataset']:
|
||||
worst_dataset, worst_score = model_perf['worst_dataset']
|
||||
recommendations.append(f"📉 {model_type.title()} needs improvement on: {worst_dataset} (score: {worst_score})")
|
||||
|
||||
return recommendations
|
||||
|
||||
def print_analysis(self):
|
||||
"""Print comprehensive analysis of validation results"""
|
||||
analysis = self.analyze_results()
|
||||
|
||||
print(f"\n📊 VALIDATION RESULTS ANALYSIS")
|
||||
print("=" * 60)
|
||||
|
||||
# Overall status
|
||||
status_emoji = {
|
||||
'excellent': '🌟',
|
||||
'good': '✅',
|
||||
'mixed': '⚠️',
|
||||
'poor': '❌',
|
||||
'unknown': '❓'
|
||||
}
|
||||
|
||||
print(f"{status_emoji.get(analysis['overall_status'], '❓')} Overall Status: {analysis['overall_status'].upper()}")
|
||||
|
||||
# Summary stats
|
||||
if 'summary_stats' in analysis and analysis['summary_stats']:
|
||||
stats = analysis['summary_stats']
|
||||
print(f"\n📈 Summary Statistics:")
|
||||
print(f" Total Improvements: {stats['total_improvements']}")
|
||||
print(f" Total Degradations: {stats['total_degradations']}")
|
||||
print(f" Net Improvement: {stats['net_improvement']:+d}")
|
||||
print(f" Improvement Ratio: {stats['improvement_ratio']:.2%}")
|
||||
|
||||
# Model performance
|
||||
if analysis['model_performance']:
|
||||
print(f"\n📊 Model Performance:")
|
||||
for model_type, perf in analysis['model_performance'].items():
|
||||
print(f" {model_type.title()}:")
|
||||
print(f" Improvements: {perf['total_improvements']}")
|
||||
print(f" Degradations: {perf['total_degradations']}")
|
||||
print(f" Average Score: {perf['avg_improvement']:.2f}")
|
||||
|
||||
if perf['best_dataset']:
|
||||
print(f" Best Dataset: {perf['best_dataset'][0]} (score: {perf['best_dataset'][1]})")
|
||||
|
||||
if perf['issues']:
|
||||
print(f" Issues: {', '.join(perf['issues'])}")
|
||||
|
||||
# Recommendations
|
||||
if analysis['recommendations']:
|
||||
print(f"\n💡 Recommendations:")
|
||||
for i, rec in enumerate(analysis['recommendations'], 1):
|
||||
print(f" {i}. {rec}")
|
||||
|
||||
return analysis
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Validation utilities for BGE models",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument("--discover-models", action="store_true",
|
||||
help="Discover trained models in workspace")
|
||||
parser.add_argument("--analyze-datasets", action="store_true",
|
||||
help="Analyze available test datasets")
|
||||
parser.add_argument("--create-sample-data", action="store_true",
|
||||
help="Create sample test datasets")
|
||||
parser.add_argument("--analyze-results", type=str, default=None,
|
||||
help="Analyze validation results in specified directory")
|
||||
parser.add_argument("--workspace-root", type=str, default=".",
|
||||
help="Root directory of workspace")
|
||||
parser.add_argument("--data-dir", type=str, default="data/datasets",
|
||||
help="Directory containing test datasets")
|
||||
parser.add_argument("--output-dir", type=str, default="data/test_samples",
|
||||
help="Output directory for sample data")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main utility function"""
|
||||
args = parse_args()
|
||||
|
||||
print("🔧 BGE Validation Utilities")
|
||||
print("=" * 50)
|
||||
|
||||
if args.discover_models:
|
||||
print("\n🔍 Discovering trained models...")
|
||||
discovery = ModelDiscovery(args.workspace_root)
|
||||
models = discovery.print_discovered_models()
|
||||
|
||||
if any(models.values()):
|
||||
print("\n💡 Use these model paths in your validation commands:")
|
||||
for model_type, model_dict in models.items():
|
||||
for name, info in model_dict.items():
|
||||
print(f" --{model_type[:-1]}_model {info['path']}")
|
||||
|
||||
if args.analyze_datasets:
|
||||
print("\n📁 Analyzing test datasets...")
|
||||
data_manager = TestDataManager(args.data_dir)
|
||||
datasets = data_manager.print_dataset_analysis()
|
||||
|
||||
if args.create_sample_data:
|
||||
print("\n📝 Creating sample test data...")
|
||||
data_manager = TestDataManager()
|
||||
sample_paths = data_manager.create_sample_test_data(args.output_dir)
|
||||
|
||||
print("\n💡 Use these sample datasets for testing:")
|
||||
print(f" --test_datasets {sample_paths['retriever_test']} {sample_paths['reranker_test']}")
|
||||
|
||||
if args.analyze_results:
|
||||
print(f"\n📊 Analyzing results from: {args.analyze_results}")
|
||||
analyzer = ValidationReportAnalyzer(args.analyze_results)
|
||||
analysis = analyzer.print_analysis()
|
||||
|
||||
if not any([args.discover_models, args.analyze_datasets, args.create_sample_data, args.analyze_results]):
|
||||
print("❓ No action specified. Use --help to see available options.")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
Reference in New Issue
Block a user