# bge_finetune/scripts/validate_reranker.py
# 2025-07-22 16:55:25 +08:00
#
# 151 lines
# 5.5 KiB
# Python

"""
Quick validation script for the fine-tuned BGE Reranker model
"""
import os
import sys
import torch
import numpy as np
from transformers import AutoTokenizer
# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_reranker import BGERerankerModel
def test_reranker_scoring(model_path, test_queries=None):
    """Load a reranker from ``model_path`` and score query-passage pairs.

    For every test case each passage is scored against the query, the
    passages are printed in descending score order, and a coarse verdict
    on the ranking quality is printed.

    Args:
        model_path: Path handed to ``AutoTokenizer.from_pretrained`` and
            ``BGERerankerModel.from_pretrained``.
        test_queries: Optional list of dicts with the keys ``"query"``
            (str) and ``"passages"`` (list of str). A case may also carry
            ``"relevant"`` and ``"irrelevant"`` (iterables of passages) to
            label the ranking check; when absent, the labels of the
            built-in deep-learning example are used. Defaults to that
            built-in example.

    Returns:
        True if the model loaded and every pair could be scored, False if
        any exception was raised along the way (the error is printed).
        The ranking verdict is informational and does not affect the
        return value.
    """
    # Labels for the built-in example; also the fallback for test cases
    # that do not provide their own "relevant" / "irrelevant" keys.
    default_relevant = ("深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习")
    default_irrelevant = ("今天天气很好",)

    if test_queries is None:
        # Test query with relevant and irrelevant passages
        test_queries = [
            {
                "query": "什么是深度学习?",
                "passages": [
                    "深度学习是机器学习的一个子领域",  # Relevant (should score high)
                    "机器学习包含多种算法",  # Somewhat relevant
                    "今天天气很好",  # Irrelevant (should score low)
                    "深度学习使用神经网络进行学习",  # Very relevant
                ],
            }
        ]

    print(f"🧪 Testing reranker from: {model_path}")

    # Top-level boundary of a validation script: any failure (missing
    # files, shape mismatch, ...) is reported and turned into False.
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGERerankerModel.from_pretrained(model_path)
        model.eval()
        print("✅ Reranker model and tokenizer loaded successfully")

        # Loop-invariant, so resolved once.
        sep_token = tokenizer.sep_token or '[SEP]'

        for query_idx, test_case in enumerate(test_queries):
            query = test_case["query"]
            passages = test_case["passages"]
            print(f"\n📝 Test Case {query_idx + 1}")
            print(f"Query: '{query}'")
            print("Passages to rank:")

            scores = []
            for i, passage in enumerate(passages):
                print(f" {i+1}. '{passage}'")

                # NOTE(review): the pair is joined manually with the
                # separator token. Passing (query, passage) as a text pair
                # would let the tokenizer insert its own model-specific
                # separators - confirm which format training used before
                # changing this.
                pair_text = f"{query} {sep_token} {passage}"
                inputs = tokenizer(
                    pair_text,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=512,
                )

                with torch.no_grad():
                    outputs = model(**inputs)

                scores.append((i, passage, _extract_score(outputs)))

            # Sort by relevance score (highest first)
            scores.sort(key=lambda item: item[2], reverse=True)

            print("\n🏆 Ranking Results:")
            for rank, (_, passage, score) in enumerate(scores, 1):
                preview = passage[:50] + ('...' if len(passage) > 50 else '')
                print(f" Rank {rank}: Score {score:.4f} - '{preview}'")

            _print_ranking_verdict(
                scores,
                set(test_case.get("relevant", default_relevant)),
                set(test_case.get("irrelevant", default_irrelevant)),
            )

        print("\n🎉 Reranker validation successful!")
        return True
    except Exception as e:
        print(f"❌ Reranker validation failed: {e}")
        return False


def _extract_score(outputs):
    """Return the single relevance score contained in a model output as a float."""
    if hasattr(outputs, 'logits'):
        return outputs.logits.item()
    if isinstance(outputs, torch.Tensor):
        return outputs.item()
    if isinstance(outputs, dict) and 'logits' in outputs:
        # A plain dict has no integer keys; outputs[0] would raise KeyError.
        return outputs['logits'].item()
    return outputs[0].item()


def _print_ranking_verdict(ranked, relevant, irrelevant):
    """Print whether labelled-relevant passages outrank labelled-irrelevant ones.

    ``ranked`` is the list of ``(index, passage, score)`` tuples sorted by
    descending score; ``relevant`` and ``irrelevant`` are sets of passages.
    """
    ranked_passages = [passage for _, passage, _ in ranked]
    relevant_present = [p for p in ranked_passages if p in relevant]
    if not relevant_present:
        # None of the labels apply to this test case, so any verdict
        # would be meaningless.
        print(" ℹ️ Ranking check skipped - no labelled relevant passages in this test case")
        return

    # The window holds as many slots as there are relevant passages. A
    # fixed top-2 window covers every passage of a two-passage test case
    # and would always report the irrelevant one as ranked high.
    top_passages = ranked_passages[:len(relevant_present)]
    relevant_in_top = any(p in relevant for p in top_passages)
    irrelevant_in_top = any(p in irrelevant for p in top_passages)

    if relevant_in_top and not irrelevant_in_top:
        print(" ✅ Good ranking - relevant passages ranked higher")
    elif relevant_in_top:
        print(" ⚠️ Partial ranking - relevant high but irrelevant also high")
    else:
        print(" ❌ Poor ranking - irrelevant passages ranked too high")
def compare_rerankers(original_path, finetuned_path):
    """Run the same scoring check on the original and the fine-tuned reranker.

    Both models are exercised with one shared query and two passages (one
    on-topic, one off-topic) via ``test_reranker_scoring``.

    Args:
        original_path: Model path of the original reranker.
        finetuned_path: Model path of the fine-tuned reranker.

    Returns:
        True only if the scoring check succeeded for both models.
    """
    print("🔄 Comparing original vs fine-tuned reranker...")

    shared_case = {
        "query": "什么是深度学习?",
        "passages": ["深度学习是机器学习的一个子领域", "今天天气很好"],
    }

    runs = (
        ("\n📍 Original Reranker:", original_path),
        ("\n📍 Fine-tuned Reranker:", finetuned_path),
    )

    # Both models are always evaluated, even if the first one fails, so
    # the outcomes are collected before being combined.
    outcomes = []
    for banner, path in runs:
        print(banner)
        outcomes.append(test_reranker_scoring(path, [shared_case]))

    return all(outcomes)
if __name__ == "__main__":
    # Script entry point: validate the fine-tuned reranker, print a
    # side-by-side run against the original model, and report the result
    # through both the console and the process exit status.
    print("🚀 BGE Reranker Model Validation")
    print("=" * 50)

    # Test fine-tuned model
    finetuned_model_path = "./test_reranker_output/final_model"
    success = test_reranker_scoring(finetuned_model_path)

    # Informational comparison with the original model; its outcome does
    # not influence the overall verdict.
    original_model_path = "./models/bge-reranker-base"
    compare_rerankers(original_model_path, finetuned_model_path)

    if success:
        print("\n✅ Reranker validation: PASSED")
        print(" The reranker can score query-passage pairs successfully!")
    else:
        print("\n❌ Reranker validation: FAILED")
        print(" The reranker has issues with scoring!")
        # Non-zero exit status so shells and CI jobs can detect the failure.
        sys.exit(1)