151 lines
5.5 KiB
Python
151 lines
5.5 KiB
Python
"""
|
|
Quick validation script for the fine-tuned BGE Reranker model
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import torch
|
|
import numpy as np
|
|
from transformers import AutoTokenizer
|
|
|
|
# Add project root to Python path
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from models.bge_reranker import BGERerankerModel
|
|
|
|
def test_reranker_scoring(model_path, test_queries=None):
    """Test if the reranker can score query-passage pairs.

    Loads the tokenizer and ``BGERerankerModel`` from ``model_path``, scores
    every passage of every test case against its query, then prints the
    passages ordered by score together with a verdict on that ordering.

    Args:
        model_path: Value passed to both ``from_pretrained`` calls.
        test_queries: Optional list of dicts, each with a ``"query"`` string
            and a ``"passages"`` list. A case may additionally carry
            ``"relevant"`` and ``"irrelevant"`` lists naming the passages
            expected to rank high / low; when absent, the labels of the
            built-in example are used. Defaults to the built-in example.

    Returns:
        True if loading and scoring completed without raising, False if any
        exception occurred (the exception message is printed). The ranking
        verdict is only printed; it does not influence the return value.
    """
    if test_queries is None:
        # Test query with relevant and irrelevant passages
        test_queries = [
            {
                "query": "什么是深度学习?",
                "passages": [
                    "深度学习是机器学习的一个子领域",  # Relevant (should score high)
                    "机器学习包含多种算法",  # Somewhat relevant
                    "今天天气很好",  # Irrelevant (should score low)
                    "深度学习使用神经网络进行学习",  # Very relevant
                ],
            }
        ]

    print(f"🧪 Testing reranker from: {model_path}")

    try:
        # Load tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGERerankerModel.from_pretrained(model_path)
        model.eval()

        print("✅ Reranker model and tokenizer loaded successfully")

        # The separator does not change between pairs, so resolve it once.
        sep_token = tokenizer.sep_token or '[SEP]'

        for query_idx, test_case in enumerate(test_queries):
            query = test_case["query"]
            passages = test_case["passages"]

            print(f"\n📝 Test Case {query_idx + 1}")
            print(f"Query: '{query}'")
            print("Passages to rank:")

            scores = []
            for i, passage in enumerate(passages):
                print(f" {i+1}. '{passage}'")

                # NOTE(review): the pair is joined into one string with the
                # separator token instead of being passed as a (query, passage)
                # text pair - presumably this mirrors the fine-tuning input
                # format; confirm against the training data pipeline.
                pair_text = f"{query} {sep_token} {passage}"

                inputs = tokenizer(
                    pair_text,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=512,
                )

                # Inference only: no gradients needed for scoring.
                with torch.no_grad():
                    outputs = model(**inputs)

                scores.append((i, passage, _extract_score(outputs)))

            # Sort by relevance score (highest first)
            scores.sort(key=lambda x: x[2], reverse=True)

            print("\n🏆 Ranking Results:")
            for rank, (_, passage, score) in enumerate(scores, 1):
                preview = passage[:50] + ('...' if len(passage) > 50 else '')
                print(f" Rank {rank}: Score {score:.4f} - '{preview}'")

            _print_ranking_verdict(test_case, scores)

        print("\n🎉 Reranker validation successful!")
        return True

    except Exception as e:
        # Top-level boundary of a validation script: report and signal failure
        # through the return value instead of propagating.
        print(f"❌ Reranker validation failed: {e}")
        return False


# Relevance labels for the built-in example; used when a test case does not
# provide its own "relevant" / "irrelevant" lists.
_DEFAULT_RELEVANT = ("深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习")
_DEFAULT_IRRELEVANT = ("今天天气很好",)


def _extract_score(outputs):
    """Return the scalar relevance score contained in a model forward result.

    Accepts an object with a ``logits`` attribute, a dict with a ``"logits"``
    key, a bare tensor, or an indexable whose first element is the score.
    ``.item()`` raises if the selected value holds more than one element.
    """
    if hasattr(outputs, 'logits'):
        logits = outputs.logits
    elif isinstance(outputs, dict) and 'logits' in outputs:
        logits = outputs['logits']
    elif isinstance(outputs, torch.Tensor):
        logits = outputs
    else:
        logits = outputs[0]
    return logits.item()


def _print_ranking_verdict(test_case, ranked):
    """Print whether the labelled relevant passages landed at the top of ``ranked``.

    ``ranked`` is the list of ``(index, passage, score)`` tuples sorted by
    score, highest first. Only labels that actually occur among the case's
    passages are considered, and the number of top slots inspected equals the
    number of such relevant passages - so a two-passage case with a single
    relevant passage is judged on rank 1 only rather than on both slots.
    """
    passages = test_case["passages"]
    relevant = [p for p in test_case.get("relevant", _DEFAULT_RELEVANT) if p in passages]
    irrelevant = [p for p in test_case.get("irrelevant", _DEFAULT_IRRELEVANT) if p in passages]

    if not relevant:
        # Nothing to compare against; a verdict would be meaningless.
        print(" ℹ️ No relevance labels apply to this test case - ranking check skipped")
        return

    top_passages = [passage for _, passage, _ in ranked[:len(relevant)]]
    relevant_in_top = any(p in top_passages for p in relevant)
    irrelevant_in_top = any(p in top_passages for p in irrelevant)

    if relevant_in_top and not irrelevant_in_top:
        print(" ✅ Good ranking - relevant passages ranked higher")
    elif relevant_in_top:
        print(" ⚠️ Partial ranking - relevant high but irrelevant also high")
    else:
        print(" ❌ Poor ranking - irrelevant passages ranked too high")
|
|
|
|
def compare_rerankers(original_path, finetuned_path):
    """Run the same two-passage check on the original and the fine-tuned reranker.

    Each model is exercised through ``test_reranker_scoring`` with one query
    and two passages, the original model first.

    Args:
        original_path: Model path handed to ``test_reranker_scoring`` for the
            original reranker.
        finetuned_path: Model path handed to ``test_reranker_scoring`` for the
            fine-tuned reranker.

    Returns:
        The ``and`` of the two ``test_reranker_scoring`` results.
    """
    print("🔄 Comparing original vs fine-tuned reranker...")

    sample = {
        "query": "什么是深度学习?",
        "passages": [
            "深度学习是机器学习的一个子领域",
            "今天天气很好",
        ],
    }

    # Both runs always execute; the results are only combined at the end.
    runs = (
        ("\n📍 Original Reranker:", original_path),
        ("\n📍 Fine-tuned Reranker:", finetuned_path),
    )
    outcomes = []
    for banner, path in runs:
        print(banner)
        outcomes.append(test_reranker_scoring(path, [sample]))

    return outcomes[0] and outcomes[1]
|
|
|
|
if __name__ == "__main__":
    # Script entry point: validate the fine-tuned reranker, print a side-by-side
    # run against the original model, and report the outcome via the exit status.
    print("🚀 BGE Reranker Model Validation")
    print("=" * 50)

    # Test fine-tuned model
    finetuned_model_path = "./test_reranker_output/final_model"
    success = test_reranker_scoring(finetuned_model_path)

    # Optional: Compare with original model. The comparison's return value is
    # not used for the verdict below.
    original_model_path = "./models/bge-reranker-base"
    compare_rerankers(original_model_path, finetuned_model_path)

    if success:
        print("\n✅ Reranker validation: PASSED")
        print(" The reranker can score query-passage pairs successfully!")
    else:
        print("\n❌ Reranker validation: FAILED")
        print(" The reranker has issues with scoring!")
        # Exit non-zero so shells and CI jobs can detect the failed validation.
        sys.exit(1)
|