bge_finetune/scripts/validation_utils.py
#!/usr/bin/env python3
"""
Validation Utilities for BGE Fine-tuning
This module provides utilities for model validation including:
- Automatic model discovery
- Test data preparation and validation
- Results analysis and reporting
- Benchmark dataset creation
Usage:
python scripts/validation_utils.py --discover-models
python scripts/validation_utils.py --prepare-test-data
python scripts/validation_utils.py --analyze-results ./validation_results
"""
import os
import sys
import json
import glob
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any
from datetime import datetime
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
class ModelDiscovery:
"""Discover trained models in the workspace"""
def __init__(self, workspace_root: str = "."):
self.workspace_root = Path(workspace_root)
def find_trained_models(self) -> Dict[str, Dict[str, str]]:
"""Find all trained BGE models in the workspace"""
models = {
'retrievers': {},
'rerankers': {},
'joint': {}
}
# Common output directories to search
search_patterns = [
"output/*/final_model",
"output/*/best_model",
"output/*/*/final_model",
"output/*/*/best_model",
"test_*_output/final_model",
"training_output/*/final_model"
]
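        # Resolve each pattern relative to the workspace root and keep only
        # directories that pass the model-validity heuristic below.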
for pattern in search_patterns:
for model_path in glob.glob(str(self.workspace_root / pattern)):
model_path = Path(model_path)
if self._is_valid_model_directory(model_path):
model_info = self._analyze_model_directory(model_path)
if model_info['type'] == 'retriever':
models['retrievers'][model_info['name']] = model_info
elif model_info['type'] == 'reranker':
models['rerankers'][model_info['name']] = model_info
elif model_info['type'] == 'joint':
models['joint'][model_info['name']] = model_info
return models
def _is_valid_model_directory(self, path: Path) -> bool:
"""Check if directory contains a valid trained model"""
required_files = ['config.json', 'pytorch_model.bin']
alternative_files = ['model.safetensors', 'tokenizer.json']
has_required = any((path / f).exists() for f in required_files)
has_alternative = any((path / f).exists() for f in alternative_files)
return has_required or has_alternative
def _analyze_model_directory(self, path: Path) -> Dict[str, str]:
"""Analyze model directory to determine type and metadata"""
info = {
'path': str(path),
'name': self._generate_model_name(path),
'type': 'unknown',
'created': 'unknown',
'training_info': {}
}
# Try to determine model type from path
path_str = str(path).lower()
if 'm3' in path_str or 'retriever' in path_str:
info['type'] = 'retriever'
elif 'reranker' in path_str:
info['type'] = 'reranker'
elif 'joint' in path_str:
info['type'] = 'joint'
# Get creation time
if path.exists():
info['created'] = datetime.fromtimestamp(path.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S')
# Try to read training info
training_summary_path = path.parent / 'training_summary.json'
if training_summary_path.exists():
try:
                with open(training_summary_path, 'r', encoding='utf-8') as f:
                    info['training_info'] = json.load(f)
            except (OSError, json.JSONDecodeError) as e:
                logger.debug(f"Could not read training summary {training_summary_path}: {e}")
# Try to read config
config_path = path / 'config.json'
if config_path.exists():
try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                if 'model_type' in config:
                    info['type'] = config['model_type']
            except (OSError, json.JSONDecodeError) as e:
                logger.debug(f"Could not read config {config_path}: {e}")
return info
def _generate_model_name(self, path: Path) -> str:
"""Generate a readable name for the model"""
parts = path.parts
# Try to find meaningful parts
meaningful_parts = []
for part in parts[-3:]: # Look at last 3 parts
if part not in ['final_model', 'best_model', 'output']:
meaningful_parts.append(part)
if meaningful_parts:
return '_'.join(meaningful_parts)
else:
return f"model_{path.parent.name}"
def print_discovered_models(self):
"""Print discovered models in a formatted way"""
models = self.find_trained_models()
print("\n🔍 DISCOVERED TRAINED MODELS")
print("=" * 60)
for model_type, model_dict in models.items():
if model_dict:
print(f"\n📊 {model_type.upper()}:")
for name, info in model_dict.items():
print(f"{name}")
print(f" Path: {info['path']}")
print(f" Created: {info['created']}")
if info['training_info']:
training_info = info['training_info']
if 'total_steps' in training_info:
print(f" Training Steps: {training_info['total_steps']}")
if 'final_metrics' in training_info and training_info['final_metrics']:
print(f" Final Metrics: {training_info['final_metrics']}")
if not any(models.values()):
print("❌ No trained models found!")
print("💡 Make sure you have trained some models first.")
return models
class TestDataManager:
"""Manage test datasets for validation"""
def __init__(self, data_dir: str = "data/datasets"):
self.data_dir = Path(data_dir)
def find_test_datasets(self) -> Dict[str, List[str]]:
"""Find available test datasets"""
datasets = {
'retriever_datasets': [],
'reranker_datasets': [],
'general_datasets': []
}
# Search for JSONL files in data directory
for jsonl_file in self.data_dir.rglob("*.jsonl"):
file_path = str(jsonl_file)
# Categorize based on filename/path
if 'reranker' in file_path.lower():
datasets['reranker_datasets'].append(file_path)
elif 'embedding' in file_path.lower() or 'm3' in file_path.lower():
datasets['retriever_datasets'].append(file_path)
else:
datasets['general_datasets'].append(file_path)
# Also search for JSON files
for json_file in self.data_dir.rglob("*.json"):
file_path = str(json_file)
if 'test' in file_path.lower() or 'dev' in file_path.lower():
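                # AFQMC (Ant Financial Question Matching) is a sentence-pair
                # corpus, so it is grouped with the reranker-style datasets.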
if 'reranker' in file_path.lower() or 'AFQMC' in file_path:
datasets['reranker_datasets'].append(file_path)
else:
datasets['general_datasets'].append(file_path)
return datasets
def validate_dataset(self, dataset_path: str) -> Dict[str, Any]:
"""Validate a dataset file and analyze its format"""
validation_result = {
'path': dataset_path,
'exists': False,
'format': 'unknown',
'sample_count': 0,
'has_queries': False,
'has_passages': False,
'has_labels': False,
'issues': [],
'sample_data': {}
}
if not os.path.exists(dataset_path):
validation_result['issues'].append(f"File does not exist: {dataset_path}")
return validation_result
validation_result['exists'] = True
try:
# Read first few lines to analyze format
samples = []
with open(dataset_path, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if i >= 5: # Only analyze first 5 lines
break
try:
sample = json.loads(line.strip())
samples.append(sample)
except json.JSONDecodeError:
validation_result['issues'].append(f"Invalid JSON at line {i+1}")
if samples:
validation_result['sample_count'] = len(samples)
validation_result['sample_data'] = samples[0] if samples else {}
# Analyze format
first_sample = samples[0]
# Check for common fields
if 'query' in first_sample:
validation_result['has_queries'] = True
if any(field in first_sample for field in ['passage', 'pos', 'neg', 'candidates']):
validation_result['has_passages'] = True
if any(field in first_sample for field in ['label', 'labels', 'score']):
validation_result['has_labels'] = True
# Determine format
if 'pos' in first_sample and 'neg' in first_sample:
validation_result['format'] = 'triplets'
elif 'candidates' in first_sample:
validation_result['format'] = 'candidates'
elif 'passage' in first_sample and 'label' in first_sample:
validation_result['format'] = 'pairs'
elif 'sentence1' in first_sample and 'sentence2' in first_sample:
validation_result['format'] = 'sentence_pairs'
else:
validation_result['format'] = 'custom'
except Exception as e:
validation_result['issues'].append(f"Error reading file: {e}")
return validation_result
def print_dataset_analysis(self):
"""Print analysis of available datasets"""
datasets = self.find_test_datasets()
print("\n📁 AVAILABLE TEST DATASETS")
print("=" * 60)
for category, dataset_list in datasets.items():
if dataset_list:
print(f"\n📊 {category.upper().replace('_', ' ')}:")
for dataset_path in dataset_list:
validation = self.validate_dataset(dataset_path)
status = "" if validation['exists'] and not validation['issues'] else "⚠️"
print(f" {status} {Path(dataset_path).name}")
print(f" Path: {dataset_path}")
print(f" Format: {validation['format']}")
if validation['issues']:
for issue in validation['issues']:
print(f" ⚠️ {issue}")
if not any(datasets.values()):
print("❌ No test datasets found!")
print("💡 Place your test data in the data/datasets/ directory")
return datasets
def create_sample_test_data(self, output_dir: str = "data/test_samples"):
"""Create sample test datasets for validation"""
os.makedirs(output_dir, exist_ok=True)
# Create sample retriever test data
retriever_test = [
{
"query": "什么是人工智能?",
"pos": [
"人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
"AI技术包括机器学习、深度学习、自然语言处理等多个领域"
],
"neg": [
"今天的天气非常好,阳光明媚",
"我喜欢听音乐和看电影"
]
},
{
"query": "如何学习编程?",
"pos": [
"学习编程需要选择合适的语言如Python、Java等然后通过实践项目提升技能",
"编程学习的关键是多写代码、多调试、多思考算法逻辑"
],
"neg": [
"烹饪是一门艺术,需要掌握火候和调味",
"运动对健康很重要每天应该锻炼30分钟"
]
}
]
# Create sample reranker test data
reranker_test = [
{
"query": "什么是人工智能?",
"passage": "人工智能是计算机科学的一个分支,致力于创建能够模拟人类智能的系统",
"label": 1
},
{
"query": "什么是人工智能?",
"passage": "今天的天气非常好,阳光明媚",
"label": 0
},
{
"query": "如何学习编程?",
"passage": "学习编程需要选择合适的语言如Python、Java等然后通过实践项目提升技能",
"label": 1
},
{
"query": "如何学习编程?",
"passage": "烹饪是一门艺术,需要掌握火候和调味",
"label": 0
}
]
# Save retriever test data
retriever_path = os.path.join(output_dir, "retriever_test.jsonl")
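        # Write one JSON object per line (JSONL); ensure_ascii=False keeps the
        # Chinese sample text human-readable in the output files.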
with open(retriever_path, 'w', encoding='utf-8') as f:
for item in retriever_test:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Save reranker test data
reranker_path = os.path.join(output_dir, "reranker_test.jsonl")
with open(reranker_path, 'w', encoding='utf-8') as f:
for item in reranker_test:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
print(f"✅ Sample test datasets created in: {output_dir}")
print(f" 📁 Retriever test: {retriever_path}")
print(f" 📁 Reranker test: {reranker_path}")
return {
'retriever_test': retriever_path,
'reranker_test': reranker_path
}
class ValidationReportAnalyzer:
"""Analyze validation results and generate insights"""
def __init__(self, results_dir: str):
self.results_dir = Path(results_dir)
def analyze_results(self) -> Dict[str, Any]:
"""Analyze validation results"""
analysis = {
'overall_status': 'unknown',
'model_performance': {},
'recommendations': [],
'summary_stats': {}
}
# Look for results files
json_results = self.results_dir / "validation_results.json"
if not json_results.exists():
analysis['recommendations'].append("No validation results found. Run validation first.")
return analysis
try:
with open(json_results, 'r') as f:
results = json.load(f)
# Analyze each model type
improvements = 0
degradations = 0
for model_type in ['retriever', 'reranker', 'pipeline']:
if model_type in results and 'comparisons' in results[model_type]:
model_analysis = self._analyze_model_results(results[model_type])
analysis['model_performance'][model_type] = model_analysis
improvements += model_analysis['total_improvements']
degradations += model_analysis['total_degradations']
# Overall status
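            # 'excellent' needs improvements to outnumber degradations by more than
            # 1.5x; equal counts are 'mixed', and fewer improvements than degradations is 'poor'.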
if improvements > degradations * 1.5:
analysis['overall_status'] = 'excellent'
elif improvements > degradations:
analysis['overall_status'] = 'good'
elif improvements == degradations:
analysis['overall_status'] = 'mixed'
else:
analysis['overall_status'] = 'poor'
# Generate recommendations
analysis['recommendations'] = self._generate_recommendations(analysis)
# Summary statistics
analysis['summary_stats'] = {
'total_improvements': improvements,
'total_degradations': degradations,
'net_improvement': improvements - degradations,
'improvement_ratio': improvements / (improvements + degradations) if (improvements + degradations) > 0 else 0
}
except Exception as e:
analysis['recommendations'].append(f"Error analyzing results: {e}")
return analysis
def _analyze_model_results(self, model_results: Dict[str, Any]) -> Dict[str, Any]:
"""Analyze results for a specific model type"""
analysis = {
'total_improvements': 0,
'total_degradations': 0,
'best_dataset': None,
'worst_dataset': None,
'avg_improvement': 0,
'issues': []
}
dataset_scores = {}
for dataset_name, comparison in model_results.get('comparisons', {}).items():
improvements = len(comparison.get('improvements', {}))
degradations = len(comparison.get('degradations', {}))
analysis['total_improvements'] += improvements
analysis['total_degradations'] += degradations
# Score dataset (improvements - degradations)
dataset_score = improvements - degradations
dataset_scores[dataset_name] = dataset_score
if dataset_scores:
analysis['best_dataset'] = max(dataset_scores.items(), key=lambda x: x[1])
analysis['worst_dataset'] = min(dataset_scores.items(), key=lambda x: x[1])
analysis['avg_improvement'] = sum(dataset_scores.values()) / len(dataset_scores)
# Identify issues
if analysis['total_degradations'] > analysis['total_improvements']:
analysis['issues'].append("More degradations than improvements")
if analysis['avg_improvement'] < 0:
analysis['issues'].append("Negative average improvement")
return analysis
def _generate_recommendations(self, analysis: Dict[str, Any]) -> List[str]:
"""Generate recommendations based on analysis"""
recommendations = []
overall_status = analysis['overall_status']
if overall_status == 'excellent':
recommendations.append("🎉 Great job! Your fine-tuned models show significant improvements.")
recommendations.append("Consider deploying these models to production.")
elif overall_status == 'good':
recommendations.append("✅ Good progress! Models show overall improvement.")
recommendations.append("Consider further fine-tuning on specific datasets where degradation occurred.")
elif overall_status == 'mixed':
recommendations.append("⚠️ Mixed results. Some improvements, some degradations.")
recommendations.append("Analyze which datasets/metrics are degrading and adjust training data.")
recommendations.append("Consider different hyperparameters or longer training.")
elif overall_status == 'poor':
recommendations.append("❌ Models show more degradations than improvements.")
recommendations.append("Review training data quality and format.")
recommendations.append("Check hyperparameters (learning rate might be too high).")
recommendations.append("Consider starting with different base models.")
# Specific model recommendations
for model_type, model_perf in analysis.get('model_performance', {}).items():
if model_perf['issues']:
recommendations.append(f"🔧 {model_type.title()} issues: {', '.join(model_perf['issues'])}")
if model_perf['best_dataset']:
best_dataset, best_score = model_perf['best_dataset']
recommendations.append(f"📊 {model_type.title()} performs best on: {best_dataset} (score: {best_score})")
if model_perf['worst_dataset']:
worst_dataset, worst_score = model_perf['worst_dataset']
recommendations.append(f"📉 {model_type.title()} needs improvement on: {worst_dataset} (score: {worst_score})")
return recommendations
def print_analysis(self):
"""Print comprehensive analysis of validation results"""
analysis = self.analyze_results()
print(f"\n📊 VALIDATION RESULTS ANALYSIS")
print("=" * 60)
# Overall status
status_emoji = {
'excellent': '🌟',
            'good': '✅',
            'mixed': '⚠️',
            'poor': '❌',
            'unknown': '❓'
}
print(f"{status_emoji.get(analysis['overall_status'], '')} Overall Status: {analysis['overall_status'].upper()}")
# Summary stats
if 'summary_stats' in analysis and analysis['summary_stats']:
stats = analysis['summary_stats']
print(f"\n📈 Summary Statistics:")
print(f" Total Improvements: {stats['total_improvements']}")
print(f" Total Degradations: {stats['total_degradations']}")
print(f" Net Improvement: {stats['net_improvement']:+d}")
print(f" Improvement Ratio: {stats['improvement_ratio']:.2%}")
# Model performance
if analysis['model_performance']:
print(f"\n📊 Model Performance:")
for model_type, perf in analysis['model_performance'].items():
print(f" {model_type.title()}:")
print(f" Improvements: {perf['total_improvements']}")
print(f" Degradations: {perf['total_degradations']}")
print(f" Average Score: {perf['avg_improvement']:.2f}")
if perf['best_dataset']:
print(f" Best Dataset: {perf['best_dataset'][0]} (score: {perf['best_dataset'][1]})")
if perf['issues']:
print(f" Issues: {', '.join(perf['issues'])}")
# Recommendations
if analysis['recommendations']:
print(f"\n💡 Recommendations:")
for i, rec in enumerate(analysis['recommendations'], 1):
print(f" {i}. {rec}")
return analysis
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Validation utilities for BGE models",
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--discover-models", action="store_true",
help="Discover trained models in workspace")
parser.add_argument("--analyze-datasets", action="store_true",
help="Analyze available test datasets")
parser.add_argument("--create-sample-data", action="store_true",
help="Create sample test datasets")
parser.add_argument("--analyze-results", type=str, default=None,
help="Analyze validation results in specified directory")
parser.add_argument("--workspace-root", type=str, default=".",
help="Root directory of workspace")
parser.add_argument("--data-dir", type=str, default="data/datasets",
help="Directory containing test datasets")
parser.add_argument("--output-dir", type=str, default="data/test_samples",
help="Output directory for sample data")
return parser.parse_args()
def main():
"""Main utility function"""
args = parse_args()
print("🔧 BGE Validation Utilities")
print("=" * 50)
if args.discover_models:
print("\n🔍 Discovering trained models...")
discovery = ModelDiscovery(args.workspace_root)
models = discovery.print_discovered_models()
if any(models.values()):
print("\n💡 Use these model paths in your validation commands:")
for model_type, model_dict in models.items():
for name, info in model_dict.items():
print(f" --{model_type[:-1]}_model {info['path']}")
if args.analyze_datasets:
print("\n📁 Analyzing test datasets...")
data_manager = TestDataManager(args.data_dir)
datasets = data_manager.print_dataset_analysis()
if args.create_sample_data:
print("\n📝 Creating sample test data...")
data_manager = TestDataManager()
sample_paths = data_manager.create_sample_test_data(args.output_dir)
print("\n💡 Use these sample datasets for testing:")
print(f" --test_datasets {sample_paths['retriever_test']} {sample_paths['reranker_test']}")
if args.analyze_results:
print(f"\n📊 Analyzing results from: {args.analyze_results}")
analyzer = ValidationReportAnalyzer(args.analyze_results)
analysis = analyzer.print_analysis()
if not any([args.discover_models, args.analyze_datasets, args.create_sample_data, args.analyze_results]):
print("❓ No action specified. Use --help to see available options.")
return 1
return 0
if __name__ == "__main__":
    sys.exit(main())