init commit
This commit is contained in:
150
scripts/validate_reranker.py
Normal file
150
scripts/validate_reranker.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Quick validation script for the fine-tuned BGE Reranker model
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add project root to Python path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from models.bge_reranker import BGERerankerModel
|
||||
|
||||
# Relevance labels for the built-in deep-learning sample query. They are used
# whenever a test case does not carry its own "relevant"/"irrelevant" lists.
_DEFAULT_RELEVANT = ("深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习")
_DEFAULT_IRRELEVANT = ("今天天气很好",)


def _extract_score(outputs):
    """Pull the single relevance score out of a model output as a float."""
    if hasattr(outputs, 'logits'):
        logits = outputs.logits
    elif isinstance(outputs, torch.Tensor):
        logits = outputs
    elif isinstance(outputs, dict) and 'logits' in outputs:
        # Defensive: a plain dict has no ``logits`` attribute, and indexing
        # it with ``[0]`` below would raise KeyError.
        logits = outputs['logits']
    else:
        logits = outputs[0]
    return logits.item()


def _ranking_verdict(ranked_passages, relevant, irrelevant):
    """Return the verdict line for a ranking, or None if nothing is labelled."""
    relevant_present = sum(1 for passage in ranked_passages if passage in relevant)
    if relevant_present == 0:
        return None

    # Inspect only as many top slots as there are relevant passages (capped
    # at 2). With a fixed top-2 window, a two-passage test case always has the
    # irrelevant passage "in the top" and could never be rated as good.
    top_passages = ranked_passages[:min(2, relevant_present)]
    relevant_in_top = any(passage in relevant for passage in top_passages)
    irrelevant_in_top = any(passage in irrelevant for passage in top_passages)

    if relevant_in_top and not irrelevant_in_top:
        return " ✅ Good ranking - relevant passages ranked higher"
    if relevant_in_top:
        return " ⚠️ Partial ranking - relevant high but irrelevant also high"
    return " ❌ Poor ranking - irrelevant passages ranked too high"


def test_reranker_scoring(model_path, test_queries=None):
    """Load a reranker and check that it can score query-passage pairs.

    Every passage of every test case is scored against its query, the
    passages are printed in descending score order, and a verdict on the
    ranking is printed. The verdict is informational only; it does not
    affect the return value.

    Args:
        model_path: Location passed to ``AutoTokenizer.from_pretrained`` and
            ``BGERerankerModel.from_pretrained``.
        test_queries: Optional list of dicts with the keys ``"query"`` (str)
            and ``"passages"`` (list of str). A dict may also carry
            ``"relevant"`` and ``"irrelevant"`` lists of passages to judge
            the ranking against; when absent, the labels of the built-in
            sample query are used. Defaults to one built-in sample query.

    Returns:
        True if loading and scoring finished without raising, False if any
        exception occurred (the exception message is printed).
    """
    if test_queries is None:
        # Test query with relevant and irrelevant passages
        test_queries = [
            {
                "query": "什么是深度学习?",
                "passages": [
                    "深度学习是机器学习的一个子领域",  # Relevant (should score high)
                    "机器学习包含多种算法",  # Somewhat relevant
                    "今天天气很好",  # Irrelevant (should score low)
                    "深度学习使用神经网络进行学习",  # Very relevant
                ],
            }
        ]

    print(f"🧪 Testing reranker from: {model_path}")

    try:
        # Load tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGERerankerModel.from_pretrained(model_path)
        model.eval()

        print("✅ Reranker model and tokenizer loaded successfully")

        # The separator depends only on the tokenizer, so resolve it once.
        sep_token = tokenizer.sep_token or '[SEP]'

        for case_number, test_case in enumerate(test_queries, 1):
            query = test_case["query"]
            passages = test_case["passages"]

            print(f"\n📝 Test Case {case_number}")
            print(f"Query: '{query}'")
            print("Passages to rank:")

            scores = []
            for position, passage in enumerate(passages, 1):
                print(f" {position}. '{passage}'")

                # NOTE(review): the pair is joined into one string instead of
                # being passed to the tokenizer as (text, text_pair);
                # presumably this mirrors the fine-tuning input format -
                # confirm against the training data pipeline.
                pair_text = f"{query} {sep_token} {passage}"

                inputs = tokenizer(
                    pair_text,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=512,
                )

                # Inference only, so skip gradient tracking.
                with torch.no_grad():
                    outputs = model(**inputs)

                scores.append((passage, _extract_score(outputs)))

            # Sort by relevance score (highest first)
            scores.sort(key=lambda item: item[1], reverse=True)

            print("\n🏆 Ranking Results:")
            for rank, (passage, score) in enumerate(scores, 1):
                print(f" Rank {rank}: Score {score:.4f} - '{passage[:50]}{'...' if len(passage) > 50 else ''}'")

            # Check if the most relevant passages ranked highest
            verdict = _ranking_verdict(
                [passage for passage, _ in scores],
                test_case.get("relevant", _DEFAULT_RELEVANT),
                test_case.get("irrelevant", _DEFAULT_IRRELEVANT),
            )
            if verdict is None:
                print(" ℹ️ No relevance labels match this test case - ranking check skipped")
            else:
                print(verdict)

        print("\n🎉 Reranker validation successful!")
        return True

    except Exception as e:
        # Top-level boundary of the validation run: report and signal failure.
        print(f"❌ Reranker validation failed: {e}")
        return False
|
||||
|
||||
def compare_rerankers(original_path, finetuned_path):
    """Run the same two-passage check on the original and fine-tuned rerankers.

    Both models are always exercised, the original first, each through
    ``test_reranker_scoring`` with one relevant and one irrelevant passage.

    Args:
        original_path: Model location of the original reranker.
        finetuned_path: Model location of the fine-tuned reranker.

    Returns:
        True only if both scoring runs reported success.
    """
    print("🔄 Comparing original vs fine-tuned reranker...")

    sample = {
        "query": "什么是深度学习?",
        "passages": [
            "深度学习是机器学习的一个子领域",
            "今天天气很好",
        ],
    }

    runs = (
        ("\n📍 Original Reranker:", original_path),
        ("\n📍 Fine-tuned Reranker:", finetuned_path),
    )

    outcomes = []
    for banner, model_path in runs:
        print(banner)
        # Collect every outcome before combining so that both models run
        # even when the first one fails.
        outcomes.append(test_reranker_scoring(model_path, [sample]))

    return all(outcomes)
|
||||
|
||||
def main():
    """Validate the fine-tuned reranker and compare it with the original.

    Returns:
        Process exit status: 0 when the fine-tuned model could be loaded and
        scored, 1 otherwise.
    """
    print("🚀 BGE Reranker Model Validation")
    print("=" * 50)

    # Test fine-tuned model
    finetuned_model_path = "./test_reranker_output/final_model"
    success = test_reranker_scoring(finetuned_model_path)

    # Optional: compare with the original model. The comparison is
    # informational only, so its result does not affect the final verdict.
    original_model_path = "./models/bge-reranker-base"
    compare_rerankers(original_model_path, finetuned_model_path)

    if success:
        print("\n✅ Reranker validation: PASSED")
        print(" The reranker can score query-passage pairs successfully!")
    else:
        print("\n❌ Reranker validation: FAILED")
        print(" The reranker has issues with scoring!")

    return 0 if success else 1


if __name__ == "__main__":
    # Propagate the verdict as the exit status so shells and CI jobs can
    # detect a failed validation instead of always seeing 0.
    sys.exit(main())
|
||||
Reference in New Issue
Block a user