init commit
This commit is contained in:
624
scripts/quick_validation.py
Normal file
624
scripts/quick_validation.py
Normal file
@@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick Validation Script for BGE Fine-tuned Models
|
||||
|
||||
This script provides fast validation to quickly check if fine-tuned models
|
||||
are working correctly and performing better than baselines.
|
||||
|
||||
Usage:
|
||||
# Quick test both models
|
||||
python scripts/quick_validation.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model \\
|
||||
--reranker_model ./output/bge-reranker/final_model
|
||||
|
||||
# Test only retriever
|
||||
python scripts/quick_validation.py \\
|
||||
--retriever_model ./output/bge-m3-enhanced/final_model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_m3 import BGEM3Model
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
from utils.config_loader import get_config
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QuickValidator:
    """Quick validation for BGE models.

    Runs a small, hand-written set of Chinese query/document test cases
    against fine-tuned and baseline retriever / reranker checkpoints and
    reports whether the fine-tuned models improve over their baselines.
    Intended as a fast smoke test, not a full benchmark.
    """

    def __init__(self):
        # Resolve the compute device once; every loaded model is moved onto it.
        self.device = self._get_device()
        self.test_queries = self._get_test_queries()

    def _get_device(self):
        """Get optimal device.

        Preference order: NPU (when configured and torch_npu is importable),
        then CUDA, then CPU. Any configuration failure falls back gracefully.
        """
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')

            # Fix: do NOT gate the NPU branch on torch.cuda.is_available() --
            # Ascend NPU hosts typically have no CUDA at all, which made the
            # original branch unreachable exactly where it was needed.
            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401  # registers the 'npu' backend
                    return torch.device("npu:0")
                except ImportError:
                    pass

            if torch.cuda.is_available():
                return torch.device("cuda")

        except Exception:
            # Config loading is best-effort; fall through to CPU.
            # (Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            pass

        return torch.device("cpu")

    def _get_test_queries(self):
        """Get standard test queries for validation.

        Each case pairs a query with two relevant and two irrelevant
        documents; the same fixture drives retriever, reranker and
        pipeline checks.
        """
        return [
            {
                "query": "什么是深度学习?",
                "relevant": [
                    "深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
                    "深度学习通过神经网络的多个层次来学习数据的抽象特征"
                ],
                "irrelevant": [
                    "今天天气很好,阳光明媚",
                    "我喜欢吃苹果和香蕉"
                ]
            },
            {
                "query": "如何制作红烧肉?",
                "relevant": [
                    "红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
                    "制作红烧肉的关键是掌握火候和糖色的调制"
                ],
                "irrelevant": [
                    "Python是一种编程语言",
                    "机器学习需要大量的数据"
                ]
            },
            {
                "query": "Python编程入门",
                "relevant": [
                    "Python是一种易学易用的编程语言,适合初学者入门",
                    "学习Python需要掌握基本语法、数据类型和控制结构"
                ],
                "irrelevant": [
                    "红烧肉是一道经典的中式菜肴",
                    "深度学习使用神经网络进行训练"
                ]
            }
        ]

    def _encode(self, tokenizer, model, text):
        """Encode *text* into a dense embedding on ``self.device``.

        Shared helper for the retriever tests and the pipeline test;
        callers are expected to be inside ``torch.no_grad()``.
        """
        inputs = tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        return model(**inputs, return_dense=True)['dense']

    def _rerank_score(self, tokenizer, model, query, doc):
        """Score one (query, doc) pair with a cross-encoder reranker.

        Returns the raw logit as a float; positive is treated as
        "relevant" by the callers.
        """
        pair_text = f"{query} [SEP] {doc}"
        inputs = tokenizer(
            pair_text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)
        return model(**inputs).logits.item()

    def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
        """Quick validation of retriever model.

        Args:
            model_path: path to the fine-tuned BGE-M3 checkpoint.
            baseline_path: reference model for the before/after comparison.

        Returns:
            dict with per-model scores, comparison and a summary,
            or None when validation fails.
        """
        print(f"\n🔍 Validating Retriever Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned model...")
            ft_scores = self._test_retriever_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline model...")
            baseline_scores = self._test_retriever_model(baseline_path, "Baseline")

            # Compare results
            comparison = self._compare_retriever_results(baseline_scores, ft_scores)

            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print(f"\n📊 Retriever Validation Summary:")
            print(f" Baseline Avg Relevance Score: {baseline_scores['avg_relevant_score']:.4f}")
            print(f" Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
            print(f" Improvement: {comparison['relevance_improvement']:.2f}%")

            if comparison['relevance_improvement'] > 0:
                print(" ✅ Fine-tuned model shows improvement!")
            else:
                print(" ❌ Fine-tuned model shows degradation!")

            results['summary'] = {
                'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['relevance_improvement']
            }

            return results

        except Exception as e:
            logger.error(f"Retriever validation failed: {e}")
            return None

    def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
        """Quick validation of reranker model.

        Args:
            model_path: path to the fine-tuned reranker checkpoint.
            baseline_path: reference model for the before/after comparison.

        Returns:
            dict with per-model scores, comparison and a summary,
            or None when validation fails.
        """
        print(f"\n🏆 Validating Reranker Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned reranker...")
            ft_scores = self._test_reranker_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline reranker...")
            baseline_scores = self._test_reranker_model(baseline_path, "Baseline")

            # Compare results
            comparison = self._compare_reranker_results(baseline_scores, ft_scores)

            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print(f"\n📊 Reranker Validation Summary:")
            print(f" Baseline Accuracy: {baseline_scores['accuracy']:.4f}")
            print(f" Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
            print(f" Improvement: {comparison['accuracy_improvement']:.2f}%")

            if comparison['accuracy_improvement'] > 0:
                print(" ✅ Fine-tuned reranker shows improvement!")
            else:
                print(" ❌ Fine-tuned reranker shows degradation!")

            results['summary'] = {
                'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['accuracy_improvement']
            }

            return results

        except Exception as e:
            logger.error(f"Reranker validation failed: {e}")
            return None

    def _test_retriever_model(self, model_path, model_type):
        """Test retriever model with predefined queries.

        Returns:
            dict with average relevant/irrelevant cosine similarity,
            per-query ranking accuracy and the relevant-irrelevant score
            separation, or None when the model cannot be loaded/run.
        """
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGEM3Model.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            print(f" ✅ {model_type} model loaded successfully")

            relevant_scores = []
            irrelevant_scores = []
            ranking_accuracy = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query_emb = self._encode(tokenizer, model, test_case["query"])

                    # Fix: compute each document similarity exactly once and
                    # reuse it for the ranking-accuracy check. The original
                    # re-encoded every document a second time, doubling the
                    # (expensive) model forward passes for identical results.
                    case_relevant = [
                        torch.cosine_similarity(
                            query_emb, self._encode(tokenizer, model, doc), dim=1
                        ).item()
                        for doc in test_case["relevant"]
                    ]
                    case_irrelevant = [
                        torch.cosine_similarity(
                            query_emb, self._encode(tokenizer, model, doc), dim=1
                        ).item()
                        for doc in test_case["irrelevant"]
                    ]

                    relevant_scores.extend(case_relevant)
                    irrelevant_scores.extend(case_irrelevant)

                    # A query "ranks correctly" when its relevant documents
                    # score higher on average than its irrelevant ones.
                    ranking_accuracy.append(
                        1.0 if np.mean(case_relevant) > np.mean(case_irrelevant) else 0.0
                    )

            return {
                'avg_relevant_score': np.mean(relevant_scores),
                'avg_irrelevant_score': np.mean(irrelevant_scores),
                'ranking_accuracy': np.mean(ranking_accuracy),
                'score_separation': np.mean(relevant_scores) - np.mean(irrelevant_scores)
            }

        except Exception as e:
            logger.error(f"Error testing {model_type} retriever: {e}")
            return None

    def _test_reranker_model(self, model_path, model_type):
        """Test reranker model with predefined query-document pairs.

        Accuracy counts a pair as correct when a relevant document gets a
        positive logit and an irrelevant one gets a non-positive logit.

        Returns:
            dict with accuracy, average positive/negative scores and the
            number of pairs tested, or None when the model cannot be run.
        """
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGERerankerModel.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            print(f" ✅ {model_type} reranker loaded successfully")

            correct_rankings = 0
            total_pairs = 0
            all_scores = []
            all_labels = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]

                    # Positive examples: expect a positive logit.
                    for doc in test_case["relevant"]:
                        score = self._rerank_score(tokenizer, model, query, doc)
                        all_scores.append(score)
                        all_labels.append(1)  # Relevant
                        if score > 0:
                            correct_rankings += 1
                        total_pairs += 1

                    # Negative examples: expect a non-positive logit.
                    for doc in test_case["irrelevant"]:
                        score = self._rerank_score(tokenizer, model, query, doc)
                        all_scores.append(score)
                        all_labels.append(0)  # Irrelevant
                        if score <= 0:
                            correct_rankings += 1
                        total_pairs += 1

            return {
                'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
                'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
                'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
                'total_pairs': total_pairs
            }

        except Exception as e:
            logger.error(f"Error testing {model_type} reranker: {e}")
            return None

    def _compare_retriever_results(self, baseline, finetuned):
        """Compare retriever results.

        Returns percentage improvements of the fine-tuned model over the
        baseline (relevance, separation, ranking); {} when either side is
        missing, and 0 for any metric whose baseline value is 0.
        """
        if not baseline or not finetuned:
            return {}

        relevance_improvement = ((finetuned['avg_relevant_score'] - baseline['avg_relevant_score']) /
                                 baseline['avg_relevant_score'] * 100) if baseline['avg_relevant_score'] != 0 else 0

        separation_improvement = ((finetuned['score_separation'] - baseline['score_separation']) /
                                  baseline['score_separation'] * 100) if baseline['score_separation'] != 0 else 0

        ranking_improvement = ((finetuned['ranking_accuracy'] - baseline['ranking_accuracy']) /
                               baseline['ranking_accuracy'] * 100) if baseline['ranking_accuracy'] != 0 else 0

        return {
            'relevance_improvement': relevance_improvement,
            'separation_improvement': separation_improvement,
            'ranking_improvement': ranking_improvement
        }

    def _compare_reranker_results(self, baseline, finetuned):
        """Compare reranker results.

        Returns the percentage accuracy improvement; {} when either side
        is missing, 0 when the baseline accuracy is 0.
        """
        if not baseline or not finetuned:
            return {}

        accuracy_improvement = ((finetuned['accuracy'] - baseline['accuracy']) /
                                baseline['accuracy'] * 100) if baseline['accuracy'] != 0 else 0

        return {
            'accuracy_improvement': accuracy_improvement
        }

    def validate_pipeline(self, retriever_path, reranker_path):
        """Quick validation of full pipeline.

        Retrieves over the fixture document pool with the retriever, then
        reranks the retrieved order with the reranker; a query counts as
        correct when the top-reranked document is relevant.

        Returns:
            dict with accuracy, query count and a status string, or None
            when validation fails.
        """
        print(f"\n🚀 Validating Full Pipeline")
        print("=" * 50)

        try:
            # Load models
            retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
            retriever = BGEM3Model.from_pretrained(retriever_path)
            retriever.to(self.device)
            retriever.eval()

            reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
            reranker = BGERerankerModel.from_pretrained(reranker_path)
            reranker.to(self.device)
            reranker.eval()

            print(" ✅ Pipeline models loaded successfully")

            total_queries = 0
            pipeline_accuracy = 0

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Candidate pool: all documents, labelled 1/0.
                    all_docs = relevant_docs + irrelevant_docs
                    doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)

                    # Step 1: Retrieval -- dense cosine similarity.
                    query_emb = self._encode(retriever_tokenizer, retriever, query)
                    doc_scores = [
                        torch.cosine_similarity(
                            query_emb,
                            self._encode(retriever_tokenizer, retriever, doc),
                            dim=1
                        ).item()
                        for doc in all_docs
                    ]

                    # Top-k for reranking (here: the whole pool, best first).
                    retrieval_indices = sorted(range(len(doc_scores)),
                                               key=lambda i: doc_scores[i], reverse=True)

                    # Step 2: Reranking with the cross-encoder.
                    rerank_scores = []
                    rerank_labels = []
                    for idx in retrieval_indices:
                        rerank_scores.append(
                            self._rerank_score(reranker_tokenizer, reranker, query, all_docs[idx])
                        )
                        rerank_labels.append(doc_labels[idx])

                    # Success iff the top-reranked document is relevant.
                    best_idx = np.argmax(rerank_scores)
                    if rerank_labels[best_idx] == 1:
                        pipeline_accuracy += 1

                    total_queries += 1

            pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0

            print(f"\n📊 Pipeline Validation Summary:")
            print(f" Pipeline Accuracy: {pipeline_acc:.4f}")
            print(f" Queries Tested: {total_queries}")

            if pipeline_acc > 0.5:
                print(" ✅ Pipeline performing well!")
            else:
                print(" ❌ Pipeline needs improvement!")

            return {
                'accuracy': pipeline_acc,
                'total_queries': total_queries,
                'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
            }

        except Exception as e:
            logger.error(f"Pipeline validation failed: {e}")
            return None
|
||||
|
||||
|
||||
def parse_args():
    """Build and parse the command-line interface for quick validation.

    Returns:
        argparse.Namespace carrying the fine-tuned model paths, the
        baseline model paths, the pipeline-test flag and an optional
        JSON output path.
    """
    parser = argparse.ArgumentParser(
        description="Quick validation of BGE fine-tuned models",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # Declarative table of plain string options: (flag, default, help).
    string_options = (
        ("--retriever_model", None, "Path to fine-tuned retriever model"),
        ("--reranker_model", None, "Path to fine-tuned reranker model"),
        ("--retriever_baseline", "BAAI/bge-m3", "Path to baseline retriever model"),
        ("--reranker_baseline", "BAAI/bge-reranker-base", "Path to baseline reranker model"),
    )
    for flag, default_value, help_text in string_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    # Boolean switch, then the output path (kept last to preserve --help order).
    parser.add_argument("--test_pipeline", action="store_true",
                        help="Test full retrieval + reranking pipeline")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Save results to JSON file")

    return parser.parse_args()
|
||||
|
||||
|
||||
def _summarize_results(results):
    """Print the final summary section and collect detected issues.

    Args:
        results: mapping of section name ('retriever' / 'reranker' /
            'pipeline') to that section's result dict.

    Returns:
        (overall_status, issues): the status banner string and the list
        of section names that regressed or underperformed.
    """
    print(f"\n{'='*60}")
    print("🎯 FINAL VALIDATION SUMMARY")
    print(f"{'='*60}")

    overall_status = "✅ SUCCESS"
    issues = []

    for model_type, model_results in results.items():
        if 'summary' in model_results:
            # Retriever / reranker sections carry a 'summary' dict.
            status = model_results['summary'].get('status', 'unknown')
            improvement = model_results['summary'].get('improvement_pct', 0)

            if status == 'improved':
                print(f"✅ {model_type.title()}: Improved by {improvement:.2f}%")
            elif status == 'degraded':
                print(f"❌ {model_type.title()}: Degraded by {abs(improvement):.2f}%")
                issues.append(model_type)
                overall_status = "⚠️ ISSUES DETECTED"
        elif model_type == 'pipeline' and 'status' in model_results:
            # Fix: pipeline results have a top-level 'status' instead of a
            # 'summary' dict, so they must be handled as a sibling branch of
            # the 'summary' check -- otherwise they would never be reported.
            if model_results['status'] == 'good':
                print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
            else:
                print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
                issues.append('pipeline')
                overall_status = "⚠️ ISSUES DETECTED"

    print(f"\n{overall_status}")
    if issues:
        print(f"⚠️ Issues detected in: {', '.join(issues)}")
        print("💡 Consider running comprehensive validation for detailed analysis")

    return overall_status, issues


def main():
    """Main validation function.

    Parses the CLI, validates whichever models were requested (plus the
    full pipeline when --test_pipeline is given and both models exist),
    prints a final summary and optionally saves results as JSON.

    Returns:
        0 on full success, 1 when no model was given or issues were found.
    """
    args = parse_args()

    if not args.retriever_model and not args.reranker_model:
        print("❌ Error: At least one model must be specified")
        print(" Use --retriever_model or --reranker_model")
        return 1

    print("🚀 BGE Quick Validation")
    print("=" * 60)
    print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    validator = QuickValidator()
    results = {}

    # Validate retriever
    if args.retriever_model:
        if os.path.exists(args.retriever_model):
            retriever_results = validator.validate_retriever(
                args.retriever_model, args.retriever_baseline
            )
            if retriever_results:
                results['retriever'] = retriever_results
        else:
            print(f"❌ Retriever model not found: {args.retriever_model}")

    # Validate reranker
    if args.reranker_model:
        if os.path.exists(args.reranker_model):
            reranker_results = validator.validate_reranker(
                args.reranker_model, args.reranker_baseline
            )
            if reranker_results:
                results['reranker'] = reranker_results
        else:
            print(f"❌ Reranker model not found: {args.reranker_model}")

    # Validate pipeline (only when both checkpoints are present on disk).
    if (args.test_pipeline and args.retriever_model and args.reranker_model and
            os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
        pipeline_results = validator.validate_pipeline(
            args.retriever_model, args.reranker_model
        )
        if pipeline_results:
            results['pipeline'] = pipeline_results

    overall_status, issues = _summarize_results(results)

    # Save results if requested (annotate AFTER summarizing so the extra
    # bookkeeping keys are not iterated as model sections).
    if args.output_file:
        results['timestamp'] = datetime.now().isoformat()
        results['overall_status'] = overall_status
        results['issues'] = issues

        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"📊 Results saved to: {args.output_file}")

    return 0 if overall_status == "✅ SUCCESS" else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fix: use sys.exit (already imported at module top) instead of the
    # interactive-only `exit` builtin provided by the `site` module, which
    # is not guaranteed to exist in all run modes (e.g. `python -S`).
    sys.exit(main())
|
||||
Reference in New Issue
Block a user