bge_finetune/scripts/quick_validation.py
#!/usr/bin/env python3
"""
Quick Validation Script for BGE Fine-tuned Models
This script provides fast validation to quickly check if fine-tuned models
are working correctly and performing better than baselines.
Usage:
# Quick test both models
python scripts/quick_validation.py \\
--retriever_model ./output/bge-m3-enhanced/final_model \\
--reranker_model ./output/bge-reranker/final_model
# Test only retriever
python scripts/quick_validation.py \\
--retriever_model ./output/bge-m3-enhanced/final_model
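
    # Combined run that also exercises the end-to-end pipeline check and writes a
    # JSON report (flags as defined in parse_args() below; the output filename is
    # only an example)
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model \\
        --reranker_model ./output/bge-reranker/final_model \\
        --test_pipeline \\
        --output_file validation_results.json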
"""
import os
import sys
import json
import argparse
import logging
import torch
import numpy as np
from transformers import AutoTokenizer
from datetime import datetime
from pathlib import Path

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)


class QuickValidator:
    """Quick validation for BGE models"""

    def __init__(self):
        self.device = self._get_device()
        self.test_queries = self._get_test_queries()

    def _get_device(self):
        """Get the best available device (NPU > CUDA > CPU)"""
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')
            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401  # registers the Ascend NPU backend
                    if torch.npu.is_available():
                        return torch.device("npu:0")
                except ImportError:
                    pass
            if torch.cuda.is_available():
                return torch.device("cuda")
        except Exception:
            pass
        return torch.device("cpu")

    def _get_test_queries(self):
        """Get standard test queries for validation"""
        return [
            {
                "query": "什么是深度学习?",
                "relevant": [
                    "深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
                    "深度学习通过神经网络的多个层次来学习数据的抽象特征"
                ],
                "irrelevant": [
                    "今天天气很好,阳光明媚",
                    "我喜欢吃苹果和香蕉"
                ]
            },
            {
                "query": "如何制作红烧肉?",
                "relevant": [
                    "红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
                    "制作红烧肉的关键是掌握火候和糖色的调制"
                ],
                "irrelevant": [
                    "Python是一种编程语言",
                    "机器学习需要大量的数据"
                ]
            },
            {
                "query": "Python编程入门",
                "relevant": [
                    "Python是一种易学易用的编程语言适合初学者入门",
                    "学习Python需要掌握基本语法、数据类型和控制结构"
                ],
                "irrelevant": [
                    "红烧肉是一道经典的中式菜肴",
                    "深度学习使用神经网络进行训练"
                ]
            }
        ]
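
    # The built-in test set is intentionally tiny. To probe a specific domain, extra
    # cases can be appended in the same shape after construction (illustrative data
    # only):
    #
    #   validator = QuickValidator()
    #   validator.test_queries.append({
    #       "query": "What is transfer learning?",
    #       "relevant": ["Transfer learning reuses a model trained on one task for a related task"],
    #       "irrelevant": ["Braised pork belly is a classic Chinese dish"],
    #   })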

    def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
        """Quick validation of retriever model"""
        print("\n🔍 Validating Retriever Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned model...")
            ft_scores = self._test_retriever_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline model...")
            baseline_scores = self._test_retriever_model(baseline_path, "Baseline")

            if not ft_scores or not baseline_scores:
                logger.error("Retriever validation failed: one of the models could not be scored")
                return None

            # Compare results
            comparison = self._compare_retriever_results(baseline_scores, ft_scores)
            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print("\n📊 Retriever Validation Summary:")
            print(f"  Baseline Avg Relevance Score:   {baseline_scores['avg_relevant_score']:.4f}")
            print(f"  Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
            print(f"  Improvement: {comparison['relevance_improvement']:.2f}%")

            if comparison['relevance_improvement'] > 0:
                print("  ✅ Fine-tuned model shows improvement!")
            else:
                print("  ❌ Fine-tuned model shows degradation!")

            results['summary'] = {
                'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['relevance_improvement']
            }
            return results
        except Exception as e:
            logger.error(f"Retriever validation failed: {e}")
            return None

    def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
        """Quick validation of reranker model"""
        print("\n🏆 Validating Reranker Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned reranker...")
            ft_scores = self._test_reranker_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline reranker...")
            baseline_scores = self._test_reranker_model(baseline_path, "Baseline")

            if not ft_scores or not baseline_scores:
                logger.error("Reranker validation failed: one of the models could not be scored")
                return None

            # Compare results
            comparison = self._compare_reranker_results(baseline_scores, ft_scores)
            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print("\n📊 Reranker Validation Summary:")
            print(f"  Baseline Accuracy:   {baseline_scores['accuracy']:.4f}")
            print(f"  Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
            print(f"  Improvement: {comparison['accuracy_improvement']:.2f}%")

            if comparison['accuracy_improvement'] > 0:
                print("  ✅ Fine-tuned reranker shows improvement!")
            else:
                print("  ❌ Fine-tuned reranker shows degradation!")

            results['summary'] = {
                'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['accuracy_improvement']
            }
            return results
        except Exception as e:
            logger.error(f"Reranker validation failed: {e}")
            return None

    def _test_retriever_model(self, model_path, model_type):
        """Test retriever model with predefined queries"""
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGEM3Model.from_pretrained(model_path)
            model.to(self.device)
            model.eval()
            print(f"{model_type} model loaded successfully")

            relevant_scores = []
            irrelevant_scores = []
            ranking_accuracy = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Encode query
                    query_inputs = tokenizer(
                        query,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=512
                    ).to(self.device)
                    query_outputs = model(**query_inputs, return_dense=True)
                    query_emb = query_outputs['dense']

                    # Score relevant documents
                    query_relevant_scores = []
                    for doc in relevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        doc_emb = model(**doc_inputs, return_dense=True)['dense']
                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        query_relevant_scores.append(similarity)
                    relevant_scores.extend(query_relevant_scores)

                    # Score irrelevant documents
                    query_irrelevant_scores = []
                    for doc in irrelevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        doc_emb = model(**doc_inputs, return_dense=True)['dense']
                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        query_irrelevant_scores.append(similarity)
                    irrelevant_scores.extend(query_irrelevant_scores)

                    # Ranking accuracy: relevant docs should score higher than irrelevant
                    # ones on average (reuse the scores computed above instead of
                    # re-encoding every document a second time).
                    ranking_accuracy.append(
                        1.0 if np.mean(query_relevant_scores) > np.mean(query_irrelevant_scores) else 0.0
                    )

            return {
                'avg_relevant_score': float(np.mean(relevant_scores)),
                'avg_irrelevant_score': float(np.mean(irrelevant_scores)),
                'ranking_accuracy': float(np.mean(ranking_accuracy)),
                'score_separation': float(np.mean(relevant_scores) - np.mean(irrelevant_scores))
            }
        except Exception as e:
            logger.error(f"Error testing {model_type} retriever: {e}")
            return None

    def _test_reranker_model(self, model_path, model_type):
        """Test reranker model with predefined query-document pairs"""
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGERerankerModel.from_pretrained(model_path)
            model.to(self.device)
            model.eval()
            print(f"{model_type} reranker loaded successfully")

            correct_rankings = 0
            total_pairs = 0
            all_scores = []
            all_labels = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Test relevant documents (positive examples)
                    for doc in relevant_docs:
                        pair_text = f"{query} [SEP] {doc}"
                        inputs = tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        outputs = model(**inputs)
                        score = outputs.logits.item()
                        all_scores.append(score)
                        all_labels.append(1)  # Relevant

                        # Check if the score indicates relevance (positive score)
                        if score > 0:
                            correct_rankings += 1
                        total_pairs += 1

                    # Test irrelevant documents (negative examples)
                    for doc in irrelevant_docs:
                        pair_text = f"{query} [SEP] {doc}"
                        inputs = tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        outputs = model(**inputs)
                        score = outputs.logits.item()
                        all_scores.append(score)
                        all_labels.append(0)  # Irrelevant

                        # Check if the score indicates irrelevance (non-positive score)
                        if score <= 0:
                            correct_rankings += 1
                        total_pairs += 1

            return {
                'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
                'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
                'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
                'total_pairs': total_pairs
            }
        except Exception as e:
            logger.error(f"Error testing {model_type} reranker: {e}")
            return None
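
    # Note: the query-document pairs above are built by string concatenation with a
    # literal "[SEP]", presumably to match how pairs were formatted during
    # fine-tuning. A common alternative for Hugging Face cross-encoders (shown only
    # as a sketch, not used here) is to let the tokenizer build the pair itself:
    #
    #   inputs = tokenizer(query, doc, return_tensors='pt',
    #                      truncation=True, max_length=512)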

    def _compare_retriever_results(self, baseline, finetuned):
        """Compare retriever results"""
        if not baseline or not finetuned:
            return {}
        relevance_improvement = ((finetuned['avg_relevant_score'] - baseline['avg_relevant_score']) /
                                 baseline['avg_relevant_score'] * 100) if baseline['avg_relevant_score'] != 0 else 0
        separation_improvement = ((finetuned['score_separation'] - baseline['score_separation']) /
                                  baseline['score_separation'] * 100) if baseline['score_separation'] != 0 else 0
        ranking_improvement = ((finetuned['ranking_accuracy'] - baseline['ranking_accuracy']) /
                               baseline['ranking_accuracy'] * 100) if baseline['ranking_accuracy'] != 0 else 0
        return {
            'relevance_improvement': relevance_improvement,
            'separation_improvement': separation_improvement,
            'ranking_improvement': ranking_improvement
        }

    def _compare_reranker_results(self, baseline, finetuned):
        """Compare reranker results"""
        if not baseline or not finetuned:
            return {}
        accuracy_improvement = ((finetuned['accuracy'] - baseline['accuracy']) /
                                baseline['accuracy'] * 100) if baseline['accuracy'] != 0 else 0
        return {
            'accuracy_improvement': accuracy_improvement
        }

    def validate_pipeline(self, retriever_path, reranker_path):
        """Quick validation of full pipeline"""
        print("\n🚀 Validating Full Pipeline")
        print("=" * 50)

        try:
            # Load models
            retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
            retriever = BGEM3Model.from_pretrained(retriever_path)
            retriever.to(self.device)
            retriever.eval()

            reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
            reranker = BGERerankerModel.from_pretrained(reranker_path)
            reranker.to(self.device)
            reranker.eval()
            print("  ✅ Pipeline models loaded successfully")

            total_queries = 0
            pipeline_accuracy = 0

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Create candidate pool
                    all_docs = relevant_docs + irrelevant_docs
                    doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)

                    # Step 1: Retrieval
                    query_inputs = retriever_tokenizer(
                        query, return_tensors='pt', padding=True,
                        truncation=True, max_length=512
                    ).to(self.device)
                    query_emb = retriever(**query_inputs, return_dense=True)['dense']

                    # Score all documents with the retriever
                    doc_scores = []
                    for doc in all_docs:
                        doc_inputs = retriever_tokenizer(
                            doc, return_tensors='pt', padding=True,
                            truncation=True, max_length=512
                        ).to(self.device)
                        doc_emb = retriever(**doc_inputs, return_dense=True)['dense']
                        score = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        doc_scores.append(score)

                    # Get top-k for reranking (all docs in this small pool)
                    retrieval_indices = sorted(range(len(doc_scores)),
                                               key=lambda i: doc_scores[i], reverse=True)

                    # Step 2: Reranking
                    rerank_scores = []
                    rerank_labels = []
                    for idx in retrieval_indices:
                        doc = all_docs[idx]
                        label = doc_labels[idx]
                        pair_text = f"{query} [SEP] {doc}"
                        pair_inputs = reranker_tokenizer(
                            pair_text, return_tensors='pt', padding=True,
                            truncation=True, max_length=512
                        ).to(self.device)
                        rerank_output = reranker(**pair_inputs)
                        rerank_scores.append(rerank_output.logits.item())
                        rerank_labels.append(label)

                    # Check whether the top-ranked document is relevant
                    best_idx = int(np.argmax(rerank_scores))
                    if rerank_labels[best_idx] == 1:
                        pipeline_accuracy += 1
                    total_queries += 1

            pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0

            print("\n📊 Pipeline Validation Summary:")
            print(f"  Pipeline Accuracy: {pipeline_acc:.4f}")
            print(f"  Queries Tested: {total_queries}")

            if pipeline_acc > 0.5:
                print("  ✅ Pipeline performing well!")
            else:
                print("  ❌ Pipeline needs improvement!")

            return {
                'accuracy': pipeline_acc,
                'total_queries': total_queries,
                'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
            }
        except Exception as e:
            logger.error(f"Pipeline validation failed: {e}")
            return None


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Quick validation of BGE fine-tuned models",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--retriever_model", type=str, default=None,
                        help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, default=None,
                        help="Path to fine-tuned reranker model")
    parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
                        help="Path or hub id of baseline retriever model")
    parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
                        help="Path or hub id of baseline reranker model")
    parser.add_argument("--test_pipeline", action="store_true",
                        help="Test full retrieval + reranking pipeline")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Save results to JSON file")
    return parser.parse_args()


def main():
    """Main validation function"""
    args = parse_args()

    if not args.retriever_model and not args.reranker_model:
        print("❌ Error: At least one model must be specified")
        print("   Use --retriever_model or --reranker_model")
        return 1

    print("🚀 BGE Quick Validation")
    print("=" * 60)
    print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    validator = QuickValidator()
    results = {}

    # Validate retriever
    if args.retriever_model:
        if os.path.exists(args.retriever_model):
            retriever_results = validator.validate_retriever(
                args.retriever_model, args.retriever_baseline
            )
            if retriever_results:
                results['retriever'] = retriever_results
        else:
            print(f"❌ Retriever model not found: {args.retriever_model}")

    # Validate reranker
    if args.reranker_model:
        if os.path.exists(args.reranker_model):
            reranker_results = validator.validate_reranker(
                args.reranker_model, args.reranker_baseline
            )
            if reranker_results:
                results['reranker'] = reranker_results
        else:
            print(f"❌ Reranker model not found: {args.reranker_model}")

    # Validate pipeline
    if (args.test_pipeline and args.retriever_model and args.reranker_model and
            os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
        pipeline_results = validator.validate_pipeline(
            args.retriever_model, args.reranker_model
        )
        if pipeline_results:
            results['pipeline'] = pipeline_results

    # Print final summary
    print(f"\n{'=' * 60}")
    print("🎯 FINAL VALIDATION SUMMARY")
    print(f"{'=' * 60}")

    overall_status = "✅ SUCCESS"
    issues = []

    for model_type, model_results in results.items():
        if 'summary' in model_results:
            status = model_results['summary'].get('status', 'unknown')
            improvement = model_results['summary'].get('improvement_pct', 0)
            if status == 'improved':
                print(f"✅ {model_type.title()}: Improved by {improvement:.2f}%")
            elif status == 'degraded':
                print(f"❌ {model_type.title()}: Degraded by {abs(improvement):.2f}%")
                issues.append(model_type)
                overall_status = "⚠️ ISSUES DETECTED"
        elif model_type == 'pipeline' and 'status' in model_results:
            if model_results['status'] == 'good':
                print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
            else:
                print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
                issues.append('pipeline')
                overall_status = "⚠️ ISSUES DETECTED"

    print(f"\n{overall_status}")
    if issues:
        print(f"⚠️ Issues detected in: {', '.join(issues)}")
        print("💡 Consider running comprehensive validation for detailed analysis")

    # Save results if requested
    if args.output_file:
        results['timestamp'] = datetime.now().isoformat()
        results['overall_status'] = overall_status
        results['issues'] = issues
        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"📊 Results saved to: {args.output_file}")

    return 0 if overall_status == "✅ SUCCESS" else 1


if __name__ == "__main__":
    sys.exit(main())
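
# Programmatic usage sketch (paths are placeholders; QuickValidator and the keys of
# its result dicts are defined above):
#
#   validator = QuickValidator()
#   report = validator.validate_retriever("./output/bge-m3-enhanced/final_model")
#   if report and report["summary"].get("status") == "improved":
#       print(f"Retriever improved by {report['summary']['improvement_pct']:.2f}%")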