"""
Quick validation script for the fine-tuned BGE Reranker model
"""

import os
import sys
import torch
import numpy as np
from transformers import AutoTokenizer

# Add project root to Python path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_reranker import BGERerankerModel

# Relevance labels for the built-in test passages. They are used whenever a
# test case does not supply its own "relevant" / "irrelevant" lists.
_KNOWN_RELEVANT = ("深度学习是机器学习的一个子领域", "深度学习使用神经网络进行学习")
_KNOWN_IRRELEVANT = ("今天天气很好",)


def _extract_score(outputs):
    """Return the scalar score from a model forward result.

    Handles three shapes, checked in this order: an object exposing
    ``.logits``, a bare ``torch.Tensor``, or an indexable whose first
    element is a tensor. ``.item()`` requires a single-element tensor,
    so each call must score exactly one query-passage pair.
    """
    if hasattr(outputs, 'logits'):
        return outputs.logits.item()
    if isinstance(outputs, torch.Tensor):
        return outputs.item()
    return outputs[0].item()


def _assess_ranking(ranked_passages, relevant, irrelevant):
    """Classify a ranking as ``"good"``, ``"partial"`` or ``"poor"``.

    Args:
        ranked_passages: Passage strings ordered best-first.
        relevant: Passages labelled as relevant.
        irrelevant: Passages labelled as irrelevant.

    Returns:
        ``"good"`` when a relevant passage and no irrelevant passage is in
        the top window, ``"partial"`` when both kinds are in it, ``"poor"``
        when no relevant passage is in it, or ``None`` when none of the
        labelled-relevant passages occur among the candidates (nothing to
        judge against).
    """
    relevant_present = set(relevant) & set(ranked_passages)
    if not relevant_present:
        return None

    # Size the window by how many relevant passages are actually present.
    # A fixed top-2 window swallows every candidate when only two passages
    # are ranked, which made the irrelevant one count as "top" regardless
    # of its score.
    top_passages = ranked_passages[:len(relevant_present)]
    relevant_in_top = any(p in relevant_present for p in top_passages)
    irrelevant_in_top = any(p in top_passages for p in irrelevant)

    if relevant_in_top and not irrelevant_in_top:
        return "good"
    if relevant_in_top:
        return "partial"
    return "poor"


def test_reranker_scoring(model_path, test_queries=None):
    """Test if the reranker can score query-passage pairs.

    Loads the tokenizer and ``BGERerankerModel`` from ``model_path``, scores
    every passage of every test case against its query, prints the passages
    ordered by score, and prints a ranking-quality verdict.

    Args:
        model_path: Path passed to both ``from_pretrained`` calls.
        test_queries: Optional list of dicts with a ``"query"`` string and a
            ``"passages"`` list. A dict may also carry ``"relevant"`` and
            ``"irrelevant"`` lists of passages used for the verdict; when
            absent, the labels of the built-in passages are used. Defaults
            to one built-in test case.

    Returns:
        ``True`` if loading and scoring completed without raising,
        ``False`` otherwise. The ranking verdict does not affect the result.
    """
    if test_queries is None:
        # Test query with relevant and irrelevant passages
        test_queries = [
            {
                "query": "什么是深度学习?",
                "passages": [
                    "深度学习是机器学习的一个子领域",  # Relevant (should score high)
                    "机器学习包含多种算法",  # Somewhat relevant
                    "今天天气很好",  # Irrelevant (should score low)
                    "深度学习使用神经网络进行学习",  # Very relevant
                ],
            }
        ]

    print(f"🧪 Testing reranker from: {model_path}")

    try:
        # Load tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = BGERerankerModel.from_pretrained(model_path)
        model.eval()

        print("✅ Reranker model and tokenizer loaded successfully")

        # The separator does not change between pairs, so look it up once.
        sep_token = tokenizer.sep_token or '[SEP]'

        for query_idx, test_case in enumerate(test_queries):
            query = test_case["query"]
            passages = test_case["passages"]

            print(f"\n📝 Test Case {query_idx + 1}")
            print(f"Query: '{query}'")
            print("Passages to rank:")

            scores = []
            for i, passage in enumerate(passages):
                print(f" {i+1}. '{passage}'")

                # Create query-passage pair.
                # NOTE(review): the pair is joined into one string instead of
                # being passed as (text, text_pair) - confirm this matches the
                # input format used at training time.
                pair_text = f"{query} {sep_token} {passage}"

                # Tokenize
                inputs = tokenizer(
                    pair_text,
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=512
                )

                # Get relevance score
                with torch.no_grad():
                    outputs = model(**inputs)

                scores.append((i, passage, _extract_score(outputs)))

            # Sort by relevance score (highest first)
            scores.sort(key=lambda x: x[2], reverse=True)

            print("\n🏆 Ranking Results:")
            for rank, (_orig_idx, passage, score) in enumerate(scores, 1):
                preview = passage[:50] + ('...' if len(passage) > 50 else '')
                print(f" Rank {rank}: Score {score:.4f} - '{preview}'")

            # Check if most relevant passages ranked highest
            verdict = _assess_ranking(
                [item[1] for item in scores],
                test_case.get("relevant", _KNOWN_RELEVANT),
                test_case.get("irrelevant", _KNOWN_IRRELEVANT),
            )

            if verdict == "good":
                print(" ✅ Good ranking - relevant passages ranked higher")
            elif verdict == "partial":
                print(" ⚠️ Partial ranking - relevant high but irrelevant also high")
            elif verdict == "poor":
                print(" ❌ Poor ranking - irrelevant passages ranked too high")
            else:
                print(" ℹ️ No relevance labels for these passages - ranking check skipped")

        print("\n🎉 Reranker validation successful!")
        return True

    except Exception as e:
        # Top-level boundary of a validation script: report and signal failure.
        print(f"❌ Reranker validation failed: {e}")
        return False


def compare_rerankers(original_path, finetuned_path):
    """Compare original vs fine-tuned reranker.

    Runs ``test_reranker_scoring`` on both model paths with the same
    two-passage test case (one relevant, one irrelevant passage).

    Args:
        original_path: Path of the original model.
        finetuned_path: Path of the fine-tuned model.

    Returns:
        ``True`` only if both runs returned ``True``.
    """
    print("🔄 Comparing original vs fine-tuned reranker...")

    test_case = {
        "query": "什么是深度学习?",
        "passages": [
            "深度学习是机器学习的一个子领域",
            "今天天气很好",
        ],
    }

    # Test original model
    print("\n📍 Original Reranker:")
    orig_success = test_reranker_scoring(original_path, [test_case])

    # Test fine-tuned model
    print("\n📍 Fine-tuned Reranker:")
    ft_success = test_reranker_scoring(finetuned_path, [test_case])

    return orig_success and ft_success


if __name__ == "__main__":
    print("🚀 BGE Reranker Model Validation")
    print("=" * 50)

    # Test fine-tuned model
    finetuned_model_path = "./test_reranker_output/final_model"
    success = test_reranker_scoring(finetuned_model_path)

    # Optional: Compare with original model.
    # The comparison result is informational only; PASSED/FAILED below
    # reflects the fine-tuned model run alone.
    original_model_path = "./models/bge-reranker-base"
    compare_rerankers(original_model_path, finetuned_model_path)

    if success:
        print("\n✅ Reranker validation: PASSED")
        print(" The reranker can score query-passage pairs successfully!")
    else:
        print("\n❌ Reranker validation: FAILED")
        print(" The reranker has issues with scoring!")