#!/usr/bin/env python3
"""
Quick Validation Script for BGE Fine-tuned Models

This script provides a fast sanity check that fine-tuned models are working
correctly and performing better than their baselines.

Usage:
    # Quick test of both models
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model \\
        --reranker_model ./output/bge-reranker/final_model

    # Test only the retriever
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model
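
    # Test the full retrieval + reranking pipeline and save results (example output path)
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model \\
        --reranker_model ./output/bge-reranker/final_model \\
        --test_pipeline \\
        --output_file quick_validation_results.json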
"""

import os
import sys
import json
import argparse
import logging
import torch
import numpy as np
from transformers import AutoTokenizer
from datetime import datetime
from pathlib import Path

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)


class QuickValidator:
    """Quick validation for BGE models"""

    def __init__(self):
        self.device = self._get_device()
        self.test_queries = self._get_test_queries()

    def _get_device(self):
        """Get the best available device: NPU (if configured), then CUDA, then CPU"""
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')

            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401 -- importing registers the Ascend NPU backend
                    return torch.device("npu:0")
                except ImportError:
                    pass

            if torch.cuda.is_available():
                return torch.device("cuda")

        except Exception:
            pass

        return torch.device("cpu")

    def _get_test_queries(self):
        """Get standard test queries for validation"""
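        # Three small Chinese test cases; the queries translate roughly to
        # "What is deep learning?", "How do you make braised pork (hong shao rou)?",
        # and "Getting started with Python programming".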
        return [
            {
                "query": "什么是深度学习?",
                "relevant": [
                    "深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
                    "深度学习通过神经网络的多个层次来学习数据的抽象特征"
                ],
                "irrelevant": [
                    "今天天气很好,阳光明媚",
                    "我喜欢吃苹果和香蕉"
                ]
            },
            {
                "query": "如何制作红烧肉?",
                "relevant": [
                    "红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
                    "制作红烧肉的关键是掌握火候和糖色的调制"
                ],
                "irrelevant": [
                    "Python是一种编程语言",
                    "机器学习需要大量的数据"
                ]
            },
            {
                "query": "Python编程入门",
                "relevant": [
                    "Python是一种易学易用的编程语言,适合初学者入门",
                    "学习Python需要掌握基本语法、数据类型和控制结构"
                ],
                "irrelevant": [
                    "红烧肉是一道经典的中式菜肴",
                    "深度学习使用神经网络进行训练"
                ]
            }
        ]

    def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
        """Quick validation of retriever model"""
        print("\n🔍 Validating Retriever Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned model...")
            ft_scores = self._test_retriever_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline model...")
            baseline_scores = self._test_retriever_model(baseline_path, "Baseline")

            # Compare results
            comparison = self._compare_retriever_results(baseline_scores, ft_scores)

            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print("\n📊 Retriever Validation Summary:")
            print(f" Baseline Avg Relevance Score: {baseline_scores['avg_relevant_score']:.4f}")
            print(f" Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
            print(f" Improvement: {comparison['relevance_improvement']:.2f}%")

            if comparison['relevance_improvement'] > 0:
                print(" ✅ Fine-tuned model shows improvement!")
            else:
                print(" ❌ Fine-tuned model shows no improvement!")

            results['summary'] = {
                'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['relevance_improvement']
            }

            return results

        except Exception as e:
            logger.error(f"Retriever validation failed: {e}")
            return None

    def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
        """Quick validation of reranker model"""
        print("\n🏆 Validating Reranker Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned reranker...")
            ft_scores = self._test_reranker_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline reranker...")
            baseline_scores = self._test_reranker_model(baseline_path, "Baseline")

            # Compare results
            comparison = self._compare_reranker_results(baseline_scores, ft_scores)

            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print("\n📊 Reranker Validation Summary:")
            print(f" Baseline Accuracy: {baseline_scores['accuracy']:.4f}")
            print(f" Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
            print(f" Improvement: {comparison['accuracy_improvement']:.2f}%")

            if comparison['accuracy_improvement'] > 0:
                print(" ✅ Fine-tuned reranker shows improvement!")
            else:
                print(" ❌ Fine-tuned reranker shows no improvement!")

            results['summary'] = {
                'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['accuracy_improvement']
            }

            return results

        except Exception as e:
            logger.error(f"Reranker validation failed: {e}")
            return None

    def _test_retriever_model(self, model_path, model_type):
        """Test retriever model with predefined queries"""
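        # For each test case: encode the query and every document with the
        # model's dense output, score each query-document pair by cosine
        # similarity, and record whether relevant documents outscore the
        # irrelevant ones on average (ranking accuracy).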
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGEM3Model.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            print(f" ✅ {model_type} model loaded successfully")

            relevant_scores = []
            irrelevant_scores = []
            ranking_accuracy = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Encode query
                    query_inputs = tokenizer(
                        query,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=512
                    ).to(self.device)

                    query_outputs = model(**query_inputs, return_dense=True)
                    query_emb = query_outputs['dense']

                    # Test relevant documents
                    query_relevant_scores = []
                    for doc in relevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        doc_outputs = model(**doc_inputs, return_dense=True)
                        doc_emb = doc_outputs['dense']

                        # Compute similarity
                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        query_relevant_scores.append(similarity)
                        relevant_scores.append(similarity)

                    # Test irrelevant documents
                    query_irrelevant_scores = []
                    for doc in irrelevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        doc_outputs = model(**doc_inputs, return_dense=True)
                        doc_emb = doc_outputs['dense']

                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        query_irrelevant_scores.append(similarity)
                        irrelevant_scores.append(similarity)

                    # Check ranking accuracy (relevant should score higher than irrelevant)
                    ranking_accuracy.append(
                        1.0 if np.mean(query_relevant_scores) > np.mean(query_irrelevant_scores) else 0.0
                    )

            return {
                'avg_relevant_score': np.mean(relevant_scores),
                'avg_irrelevant_score': np.mean(irrelevant_scores),
                'ranking_accuracy': np.mean(ranking_accuracy),
                'score_separation': np.mean(relevant_scores) - np.mean(irrelevant_scores)
            }

        except Exception as e:
            logger.error(f"Error testing {model_type} retriever: {e}")
            return None

    def _test_reranker_model(self, model_path, model_type):
        """Test reranker model with predefined query-document pairs"""
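        # Each query is paired with its relevant and irrelevant documents in a
        # single input string; the reranker's single logit is treated as a
        # relevance score, with > 0 counted as relevant and <= 0 as irrelevant.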
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGERerankerModel.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            print(f" ✅ {model_type} reranker loaded successfully")

            correct_rankings = 0
            total_pairs = 0
            all_scores = []
            all_labels = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Test relevant documents (positive examples)
                    for doc in relevant_docs:
                        pair_text = f"{query} [SEP] {doc}"
                        inputs = tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        outputs = model(**inputs)
                        score = outputs.logits.item()
                        all_scores.append(score)
                        all_labels.append(1)  # Relevant

                        # Check if score indicates relevance (positive score)
                        if score > 0:
                            correct_rankings += 1
                        total_pairs += 1

                    # Test irrelevant documents (negative examples)
                    for doc in irrelevant_docs:
                        pair_text = f"{query} [SEP] {doc}"
                        inputs = tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        outputs = model(**inputs)
                        score = outputs.logits.item()
                        all_scores.append(score)
                        all_labels.append(0)  # Irrelevant

                        # Check if score indicates irrelevance (non-positive score)
                        if score <= 0:
                            correct_rankings += 1
                        total_pairs += 1

            return {
                'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
                'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
                'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
                'total_pairs': total_pairs
            }

        except Exception as e:
            logger.error(f"Error testing {model_type} reranker: {e}")
            return None

    def _compare_retriever_results(self, baseline, finetuned):
        """Compare retriever results"""
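        # Each improvement below is the relative change in percent:
        # (fine-tuned - baseline) / baseline * 100, guarded against a zero baseline.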
        if not baseline or not finetuned:
            return {}

        relevance_improvement = ((finetuned['avg_relevant_score'] - baseline['avg_relevant_score']) /
                                 baseline['avg_relevant_score'] * 100) if baseline['avg_relevant_score'] != 0 else 0

        separation_improvement = ((finetuned['score_separation'] - baseline['score_separation']) /
                                  baseline['score_separation'] * 100) if baseline['score_separation'] != 0 else 0

        ranking_improvement = ((finetuned['ranking_accuracy'] - baseline['ranking_accuracy']) /
                               baseline['ranking_accuracy'] * 100) if baseline['ranking_accuracy'] != 0 else 0

        return {
            'relevance_improvement': relevance_improvement,
            'separation_improvement': separation_improvement,
            'ranking_improvement': ranking_improvement
        }

    def _compare_reranker_results(self, baseline, finetuned):
        """Compare reranker results"""
        if not baseline or not finetuned:
            return {}

        accuracy_improvement = ((finetuned['accuracy'] - baseline['accuracy']) /
                                baseline['accuracy'] * 100) if baseline['accuracy'] != 0 else 0

        return {
            'accuracy_improvement': accuracy_improvement
        }

    def validate_pipeline(self, retriever_path, reranker_path):
        """Quick validation of full pipeline"""
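        # For each test query: (1) rank the candidate pool (relevant + irrelevant
        # docs) by dense cosine similarity, (2) rescore the candidates with the
        # reranker, then count the query as correct if the top reranked document
        # is one of the relevant ones (precision@1).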
        print("\n🚀 Validating Full Pipeline")
        print("=" * 50)

        try:
            # Load models
            retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
            retriever = BGEM3Model.from_pretrained(retriever_path)
            retriever.to(self.device)
            retriever.eval()

            reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
            reranker = BGERerankerModel.from_pretrained(reranker_path)
            reranker.to(self.device)
            reranker.eval()

            print(" ✅ Pipeline models loaded successfully")

            total_queries = 0
            pipeline_accuracy = 0

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Create candidate pool
                    all_docs = relevant_docs + irrelevant_docs
                    doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)

                    # Step 1: Retrieval
                    query_inputs = retriever_tokenizer(
                        query, return_tensors='pt', padding=True,
                        truncation=True, max_length=512
                    ).to(self.device)
                    query_emb = retriever(**query_inputs, return_dense=True)['dense']

                    # Score all documents with retriever
                    doc_scores = []
                    for doc in all_docs:
                        doc_inputs = retriever_tokenizer(
                            doc, return_tensors='pt', padding=True,
                            truncation=True, max_length=512
                        ).to(self.device)
                        doc_emb = retriever(**doc_inputs, return_dense=True)['dense']
                        score = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        doc_scores.append(score)

                    # Get top-k for reranking (all docs in this case)
                    retrieval_indices = sorted(range(len(doc_scores)),
                                               key=lambda i: doc_scores[i], reverse=True)

                    # Step 2: Reranking
                    rerank_scores = []
                    rerank_labels = []

                    for idx in retrieval_indices:
                        doc = all_docs[idx]
                        label = doc_labels[idx]

                        pair_text = f"{query} [SEP] {doc}"
                        pair_inputs = reranker_tokenizer(
                            pair_text, return_tensors='pt', padding=True,
                            truncation=True, max_length=512
                        ).to(self.device)

                        rerank_output = reranker(**pair_inputs)
                        rerank_score = rerank_output.logits.item()

                        rerank_scores.append(rerank_score)
                        rerank_labels.append(label)

                    # Check if top-ranked document is relevant
                    best_idx = np.argmax(rerank_scores)
                    if rerank_labels[best_idx] == 1:
                        pipeline_accuracy += 1

                    total_queries += 1

            pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0

            print("\n📊 Pipeline Validation Summary:")
            print(f" Pipeline Accuracy: {pipeline_acc:.4f}")
            print(f" Queries Tested: {total_queries}")

            if pipeline_acc > 0.5:
                print(" ✅ Pipeline performing well!")
            else:
                print(" ❌ Pipeline needs improvement!")

            return {
                'accuracy': pipeline_acc,
                'total_queries': total_queries,
                'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
            }

        except Exception as e:
            logger.error(f"Pipeline validation failed: {e}")
            return None


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Quick validation of BGE fine-tuned models",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("--retriever_model", type=str, default=None,
                        help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, default=None,
                        help="Path to fine-tuned reranker model")
    parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
                        help="Path to baseline retriever model")
    parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
                        help="Path to baseline reranker model")
    parser.add_argument("--test_pipeline", action="store_true",
                        help="Test full retrieval + reranking pipeline")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Save results to JSON file")

    return parser.parse_args()


def main():
    """Main validation function"""
    args = parse_args()

    if not args.retriever_model and not args.reranker_model:
        print("❌ Error: At least one model must be specified")
        print(" Use --retriever_model or --reranker_model")
        return 1

    print("🚀 BGE Quick Validation")
    print("=" * 60)
    print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    validator = QuickValidator()
    results = {}

    # Validate retriever
    if args.retriever_model:
        if os.path.exists(args.retriever_model):
            retriever_results = validator.validate_retriever(
                args.retriever_model, args.retriever_baseline
            )
            if retriever_results:
                results['retriever'] = retriever_results
        else:
            print(f"❌ Retriever model not found: {args.retriever_model}")

    # Validate reranker
    if args.reranker_model:
        if os.path.exists(args.reranker_model):
            reranker_results = validator.validate_reranker(
                args.reranker_model, args.reranker_baseline
            )
            if reranker_results:
                results['reranker'] = reranker_results
        else:
            print(f"❌ Reranker model not found: {args.reranker_model}")

    # Validate pipeline
    if (args.test_pipeline and args.retriever_model and args.reranker_model and
            os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
        pipeline_results = validator.validate_pipeline(
            args.retriever_model, args.reranker_model
        )
        if pipeline_results:
            results['pipeline'] = pipeline_results

    # Print final summary
    print(f"\n{'='*60}")
    print("🎯 FINAL VALIDATION SUMMARY")
    print(f"{'='*60}")

    overall_status = "✅ SUCCESS"
    issues = []

    for model_type, model_results in results.items():
        if 'summary' in model_results:
            status = model_results['summary'].get('status', 'unknown')
            improvement = model_results['summary'].get('improvement_pct', 0)

            if status == 'improved':
                print(f"✅ {model_type.title()}: Improved by {improvement:.2f}%")
            elif status == 'degraded':
                print(f"❌ {model_type.title()}: Degraded by {abs(improvement):.2f}%")
                issues.append(model_type)
                overall_status = "⚠️ ISSUES DETECTED"
        elif model_type == 'pipeline' and 'status' in model_results:
            if model_results['status'] == 'good':
                print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
            else:
                print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
                issues.append('pipeline')
                overall_status = "⚠️ ISSUES DETECTED"

    print(f"\n{overall_status}")
    if issues:
        print(f"⚠️ Issues detected in: {', '.join(issues)}")
        print("💡 Consider running comprehensive validation for detailed analysis")

    # Save results if requested
    if args.output_file:
        results['timestamp'] = datetime.now().isoformat()
        results['overall_status'] = overall_status
        results['issues'] = issues

        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"📊 Results saved to: {args.output_file}")

    return 0 if overall_status == "✅ SUCCESS" else 1


if __name__ == "__main__":
    sys.exit(main())