init commit

Author: ldy
Date: 2025-07-22 16:55:25 +08:00
Commit: 36003b83e2
67 changed files with 76613 additions and 0 deletions

scripts/quick_validation.py (new file, 624 lines)

@@ -0,0 +1,624 @@
#!/usr/bin/env python3
"""
Quick Validation Script for BGE Fine-tuned Models

This script provides a fast sanity check that fine-tuned models are working
correctly and performing better than their baselines.

Usage:
    # Quick test of both models
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model \\
        --reranker_model ./output/bge-reranker/final_model

    # Test only the retriever
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model
"""
import os
import sys
import json
import argparse
import logging
import torch
import numpy as np
from transformers import AutoTokenizer
from datetime import datetime
from pathlib import Path
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from utils.config_loader import get_config
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)
class QuickValidator:
    """Quick validation for BGE models"""

    def __init__(self):
        self.device = self._get_device()
        self.test_queries = self._get_test_queries()

    def _get_device(self):
        """Get the optimal device: NPU if configured and available, else CUDA, else CPU"""
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')
            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401  (registers the NPU backend)
                    return torch.device("npu:0")
                except ImportError:
                    pass
            if torch.cuda.is_available():
                return torch.device("cuda")
        except Exception:
            pass
        return torch.device("cpu")
    def _get_test_queries(self):
        """Get standard test queries for validation"""
        return [
            {
                "query": "什么是深度学习?",
                "relevant": [
                    "深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
                    "深度学习通过神经网络的多个层次来学习数据的抽象特征"
                ],
                "irrelevant": [
                    "今天天气很好,阳光明媚",
                    "我喜欢吃苹果和香蕉"
                ]
            },
            {
                "query": "如何制作红烧肉?",
                "relevant": [
                    "红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
                    "制作红烧肉的关键是掌握火候和糖色的调制"
                ],
                "irrelevant": [
                    "Python是一种编程语言",
                    "机器学习需要大量的数据"
                ]
            },
            {
                "query": "Python编程入门",
                "relevant": [
                    "Python是一种易学易用的编程语言适合初学者入门",
                    "学习Python需要掌握基本语法、数据类型和控制结构"
                ],
                "irrelevant": [
                    "红烧肉是一道经典的中式菜肴",
                    "深度学习使用神经网络进行训练"
                ]
            }
        ]
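    # NOTE: the three cases above give 3 queries x (2 relevant + 2 irrelevant) documents,
    # i.e. 12 query-document pairs in total, so reranker accuracy below is measured over
    # 12 pairs and retriever/pipeline ranking accuracy over 3 queries.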
    def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
        """Quick validation of retriever model"""
        print(f"\n🔍 Validating Retriever Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned model...")
            ft_scores = self._test_retriever_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline model...")
            baseline_scores = self._test_retriever_model(baseline_path, "Baseline")

            # Compare results
            comparison = self._compare_retriever_results(baseline_scores, ft_scores)

            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print(f"\n📊 Retriever Validation Summary:")
            print(f" Baseline Avg Relevance Score: {baseline_scores['avg_relevant_score']:.4f}")
            print(f" Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
            print(f" Improvement: {comparison['relevance_improvement']:.2f}%")

            if comparison['relevance_improvement'] > 0:
                print(" ✅ Fine-tuned model shows improvement!")
            else:
                print(" ❌ Fine-tuned model shows degradation!")

            results['summary'] = {
                'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['relevance_improvement']
            }

            return results

        except Exception as e:
            logger.error(f"Retriever validation failed: {e}")
            return None
    def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
        """Quick validation of reranker model"""
        print(f"\n🏆 Validating Reranker Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned reranker...")
            ft_scores = self._test_reranker_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline reranker...")
            baseline_scores = self._test_reranker_model(baseline_path, "Baseline")

            # Compare results
            comparison = self._compare_reranker_results(baseline_scores, ft_scores)

            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print(f"\n📊 Reranker Validation Summary:")
            print(f" Baseline Accuracy: {baseline_scores['accuracy']:.4f}")
            print(f" Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
            print(f" Improvement: {comparison['accuracy_improvement']:.2f}%")

            if comparison['accuracy_improvement'] > 0:
                print(" ✅ Fine-tuned reranker shows improvement!")
            else:
                print(" ❌ Fine-tuned reranker shows degradation!")

            results['summary'] = {
                'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['accuracy_improvement']
            }

            return results

        except Exception as e:
            logger.error(f"Reranker validation failed: {e}")
            return None
    def _test_retriever_model(self, model_path, model_type):
        """Test retriever model with predefined queries"""
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGEM3Model.from_pretrained(model_path)
            model.to(self.device)
            model.eval()
            print(f"{model_type} model loaded successfully")

            relevant_scores = []
            irrelevant_scores = []
            ranking_accuracy = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Encode query
                    query_inputs = tokenizer(
                        query,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=512
                    ).to(self.device)
                    query_outputs = model(**query_inputs, return_dense=True)
                    query_emb = query_outputs['dense']

                    # Test relevant documents
                    query_relevant_scores = []
                    for doc in relevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        doc_outputs = model(**doc_inputs, return_dense=True)
                        doc_emb = doc_outputs['dense']

                        # Compute similarity
                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        query_relevant_scores.append(similarity)

                    # Test irrelevant documents
                    query_irrelevant_scores = []
                    for doc in irrelevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        doc_outputs = model(**doc_inputs, return_dense=True)
                        doc_emb = doc_outputs['dense']

                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        query_irrelevant_scores.append(similarity)

                    relevant_scores.extend(query_relevant_scores)
                    irrelevant_scores.extend(query_irrelevant_scores)

                    # Check ranking accuracy (relevant should score higher than irrelevant),
                    # reusing the similarities computed above instead of re-encoding every document
                    ranking_accuracy.append(
                        1.0 if np.mean(query_relevant_scores) > np.mean(query_irrelevant_scores) else 0.0
                    )

            return {
                'avg_relevant_score': np.mean(relevant_scores),
                'avg_irrelevant_score': np.mean(irrelevant_scores),
                'ranking_accuracy': np.mean(ranking_accuracy),
                'score_separation': np.mean(relevant_scores) - np.mean(irrelevant_scores)
            }

        except Exception as e:
            logger.error(f"Error testing {model_type} retriever: {e}")
            return None
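    # NOTE: 'score_separation' is mean(relevant similarities) - mean(irrelevant similarities);
    # larger values mean the embedding space separates relevant from irrelevant text more cleanly.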
    def _test_reranker_model(self, model_path, model_type):
        """Test reranker model with predefined query-document pairs"""
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGERerankerModel.from_pretrained(model_path)
            model.to(self.device)
            model.eval()
            print(f"{model_type} reranker loaded successfully")

            correct_rankings = 0
            total_pairs = 0
            all_scores = []
            all_labels = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Test relevant documents (positive examples)
                    for doc in relevant_docs:
                        pair_text = f"{query} [SEP] {doc}"
                        inputs = tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        outputs = model(**inputs)
                        score = outputs.logits.item()
                        all_scores.append(score)
                        all_labels.append(1)  # Relevant

                        # Check if score indicates relevance (positive score)
                        if score > 0:
                            correct_rankings += 1
                        total_pairs += 1

                    # Test irrelevant documents (negative examples)
                    for doc in irrelevant_docs:
                        pair_text = f"{query} [SEP] {doc}"
                        inputs = tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        outputs = model(**inputs)
                        score = outputs.logits.item()
                        all_scores.append(score)
                        all_labels.append(0)  # Irrelevant

                        # Check if score indicates irrelevance (non-positive score)
                        if score <= 0:
                            correct_rankings += 1
                        total_pairs += 1

            return {
                'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
                'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
                'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
                'total_pairs': total_pairs
            }

        except Exception as e:
            logger.error(f"Error testing {model_type} reranker: {e}")
            return None
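    # NOTE: reranker 'accuracy' is a sign-threshold metric: a pair counts as correct when a
    # relevant document gets a positive logit or an irrelevant one gets a non-positive logit.
    # It does not directly measure ranking order within a candidate list.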
    def _compare_retriever_results(self, baseline, finetuned):
        """Compare retriever results"""
        if not baseline or not finetuned:
            return {}

        relevance_improvement = ((finetuned['avg_relevant_score'] - baseline['avg_relevant_score']) /
                                 baseline['avg_relevant_score'] * 100) if baseline['avg_relevant_score'] != 0 else 0
        separation_improvement = ((finetuned['score_separation'] - baseline['score_separation']) /
                                  baseline['score_separation'] * 100) if baseline['score_separation'] != 0 else 0
        ranking_improvement = ((finetuned['ranking_accuracy'] - baseline['ranking_accuracy']) /
                               baseline['ranking_accuracy'] * 100) if baseline['ranking_accuracy'] != 0 else 0

        return {
            'relevance_improvement': relevance_improvement,
            'separation_improvement': separation_improvement,
            'ranking_improvement': ranking_improvement
        }

    def _compare_reranker_results(self, baseline, finetuned):
        """Compare reranker results"""
        if not baseline or not finetuned:
            return {}

        accuracy_improvement = ((finetuned['accuracy'] - baseline['accuracy']) /
                                baseline['accuracy'] * 100) if baseline['accuracy'] != 0 else 0

        return {
            'accuracy_improvement': accuracy_improvement
        }
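    # All improvement figures are relative percentages:
    #     improvement_pct = (finetuned_metric - baseline_metric) / baseline_metric * 100
    # e.g. a baseline of 0.80 and a fine-tuned score of 0.88 reports +10.00%.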
    def validate_pipeline(self, retriever_path, reranker_path):
        """Quick validation of full pipeline"""
        print(f"\n🚀 Validating Full Pipeline")
        print("=" * 50)

        try:
            # Load models
            retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
            retriever = BGEM3Model.from_pretrained(retriever_path)
            retriever.to(self.device)
            retriever.eval()

            reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
            reranker = BGERerankerModel.from_pretrained(reranker_path)
            reranker.to(self.device)
            reranker.eval()

            print(" ✅ Pipeline models loaded successfully")

            total_queries = 0
            pipeline_accuracy = 0

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Create candidate pool
                    all_docs = relevant_docs + irrelevant_docs
                    doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)

                    # Step 1: Retrieval
                    query_inputs = retriever_tokenizer(
                        query, return_tensors='pt', padding=True,
                        truncation=True, max_length=512
                    ).to(self.device)
                    query_emb = retriever(**query_inputs, return_dense=True)['dense']

                    # Score all documents with retriever
                    doc_scores = []
                    for doc in all_docs:
                        doc_inputs = retriever_tokenizer(
                            doc, return_tensors='pt', padding=True,
                            truncation=True, max_length=512
                        ).to(self.device)
                        doc_emb = retriever(**doc_inputs, return_dense=True)['dense']
                        score = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        doc_scores.append(score)

                    # Get top-k for reranking (all docs in this case)
                    retrieval_indices = sorted(range(len(doc_scores)),
                                               key=lambda i: doc_scores[i], reverse=True)

                    # Step 2: Reranking
                    rerank_scores = []
                    rerank_labels = []
                    for idx in retrieval_indices:
                        doc = all_docs[idx]
                        label = doc_labels[idx]
                        pair_text = f"{query} [SEP] {doc}"
                        pair_inputs = reranker_tokenizer(
                            pair_text, return_tensors='pt', padding=True,
                            truncation=True, max_length=512
                        ).to(self.device)
                        rerank_output = reranker(**pair_inputs)
                        rerank_score = rerank_output.logits.item()
                        rerank_scores.append(rerank_score)
                        rerank_labels.append(label)

                    # Check if top-ranked document is relevant
                    best_idx = np.argmax(rerank_scores)
                    if rerank_labels[best_idx] == 1:
                        pipeline_accuracy += 1
                    total_queries += 1

            pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0

            print(f"\n📊 Pipeline Validation Summary:")
            print(f" Pipeline Accuracy: {pipeline_acc:.4f}")
            print(f" Queries Tested: {total_queries}")

            if pipeline_acc > 0.5:
                print(" ✅ Pipeline performing well!")
            else:
                print(" ❌ Pipeline needs improvement!")

            return {
                'accuracy': pipeline_acc,
                'total_queries': total_queries,
                'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
            }

        except Exception as e:
            logger.error(f"Pipeline validation failed: {e}")
            return None
def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Quick validation of BGE fine-tuned models",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--retriever_model", type=str, default=None,
                        help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, default=None,
                        help="Path to fine-tuned reranker model")
    parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
                        help="Path to baseline retriever model")
    parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
                        help="Path to baseline reranker model")
    parser.add_argument("--test_pipeline", action="store_true",
                        help="Test full retrieval + reranking pipeline")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Save results to JSON file")
    return parser.parse_args()
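# Example (sketch): QuickValidator can also be driven directly from Python rather than via
# the CLI; the checkpoint path below is illustrative and must exist on disk.
#
#   validator = QuickValidator()
#   report = validator.validate_retriever("./output/bge-m3-enhanced/final_model")
#   if report:
#       print(report["summary"])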
def main():
    """Main validation function"""
    args = parse_args()

    if not args.retriever_model and not args.reranker_model:
        print("❌ Error: At least one model must be specified")
        print(" Use --retriever_model or --reranker_model")
        return 1

    print("🚀 BGE Quick Validation")
    print("=" * 60)
    print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    validator = QuickValidator()
    results = {}

    # Validate retriever
    if args.retriever_model:
        if os.path.exists(args.retriever_model):
            retriever_results = validator.validate_retriever(
                args.retriever_model, args.retriever_baseline
            )
            if retriever_results:
                results['retriever'] = retriever_results
        else:
            print(f"❌ Retriever model not found: {args.retriever_model}")

    # Validate reranker
    if args.reranker_model:
        if os.path.exists(args.reranker_model):
            reranker_results = validator.validate_reranker(
                args.reranker_model, args.reranker_baseline
            )
            if reranker_results:
                results['reranker'] = reranker_results
        else:
            print(f"❌ Reranker model not found: {args.reranker_model}")

    # Validate pipeline
    if (args.test_pipeline and args.retriever_model and args.reranker_model and
            os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
        pipeline_results = validator.validate_pipeline(
            args.retriever_model, args.reranker_model
        )
        if pipeline_results:
            results['pipeline'] = pipeline_results

    # Print final summary
    print(f"\n{'='*60}")
    print("🎯 FINAL VALIDATION SUMMARY")
    print(f"{'='*60}")

    overall_status = "✅ SUCCESS"
    issues = []

    for model_type, model_results in results.items():
        if 'summary' in model_results:
            status = model_results['summary'].get('status', 'unknown')
            improvement = model_results['summary'].get('improvement_pct', 0)
            if status == 'improved':
                print(f"{model_type.title()}: Improved by {improvement:.2f}%")
            elif status == 'degraded':
                print(f"{model_type.title()}: Degraded by {abs(improvement):.2f}%")
                issues.append(model_type)
                overall_status = "⚠️ ISSUES DETECTED"
        elif model_type == 'pipeline' and 'status' in model_results:
            if model_results['status'] == 'good':
                print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
            else:
                print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
                issues.append('pipeline')
                overall_status = "⚠️ ISSUES DETECTED"

    print(f"\n{overall_status}")
    if issues:
        print(f"⚠️ Issues detected in: {', '.join(issues)}")
        print("💡 Consider running comprehensive validation for detailed analysis")

    # Save results if requested
    if args.output_file:
        results['timestamp'] = datetime.now().isoformat()
        results['overall_status'] = overall_status
        results['issues'] = issues
        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"📊 Results saved to: {args.output_file}")

    return 0 if overall_status == "✅ SUCCESS" else 1
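# Sketch of the JSON written via --output_file, derived from the result dictionaries above
# (exact keys depend on which models were validated; values shown as placeholders):
# {
#   "retriever": {"model_path": "...", "baseline_path": "...", "test_results": {...},
#                 "summary": {"status": "improved", "improvement_pct": ...}},
#   "reranker": {...},
#   "pipeline": {"accuracy": ..., "total_queries": ..., "status": "good" or "needs_improvement"},
#   "timestamp": "...", "overall_status": "...", "issues": [...]
# }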
if __name__ == "__main__":
exit(main())