#!/usr/bin/env python3
"""
Quick Validation Script for BGE Fine-tuned Models

This script provides fast validation to quickly check whether fine-tuned models
are working correctly and performing better than their baselines.

Usage:
    # Quick test of both models
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model \\
        --reranker_model ./output/bge-reranker/final_model

    # Test only the retriever
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model

    # Also test the full retrieval + reranking pipeline
    python scripts/quick_validation.py \\
        --retriever_model ./output/bge-m3-enhanced/final_model \\
        --reranker_model ./output/bge-reranker/final_model \\
        --test_pipeline
"""

import os
import sys
import json
import argparse
import logging
import torch
import numpy as np
from transformers import AutoTokenizer
from datetime import datetime
from pathlib import Path

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from utils.config_loader import get_config

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)


class QuickValidator:
    """Quick validation for BGE models"""

    def __init__(self):
        self.device = self._get_device()
        self.test_queries = self._get_test_queries()

    def _get_device(self):
        """Get the optimal available device (NPU > CUDA > CPU)"""
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')
            # Prefer NPU only when configured; importing torch_npu registers the backend.
            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401
                    return torch.device("npu:0")
                except ImportError:
                    pass
            if torch.cuda.is_available():
                return torch.device("cuda")
        except Exception:
            pass
        return torch.device("cpu")

    def _get_test_queries(self):
        """Get standard test queries for validation"""
        # Chinese test cases (BGE-M3 and the BGE reranker are multilingual).
        # Topics: deep learning, braised pork (hongshaorou) cooking, Python basics.
        return [
            {
                "query": "什么是深度学习?",
                "relevant": [
                    "深度学习是机器学习的一个子领域,使用多层神经网络来学习数据表示",
                    "深度学习通过神经网络的多个层次来学习数据的抽象特征"
                ],
                "irrelevant": [
                    "今天天气很好,阳光明媚",
                    "我喜欢吃苹果和香蕉"
                ]
            },
            {
                "query": "如何制作红烧肉?",
                "relevant": [
                    "红烧肉的制作需要五花肉、生抽、老抽、冰糖等材料,先炒糖色再炖煮",
                    "制作红烧肉的关键是掌握火候和糖色的调制"
                ],
                "irrelevant": [
                    "Python是一种编程语言",
                    "机器学习需要大量的数据"
                ]
            },
            {
                "query": "Python编程入门",
                "relevant": [
                    "Python是一种易学易用的编程语言,适合初学者入门",
                    "学习Python需要掌握基本语法、数据类型和控制结构"
                ],
                "irrelevant": [
                    "红烧肉是一道经典的中式菜肴",
                    "深度学习使用神经网络进行训练"
                ]
            }
        ]

    def validate_retriever(self, model_path, baseline_path="BAAI/bge-m3"):
        """Quick validation of retriever model"""
        print("\n🔍 Validating Retriever Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned model...")
            ft_scores = self._test_retriever_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline model...")
            baseline_scores = self._test_retriever_model(baseline_path, "Baseline")

            # Compare results
            comparison = self._compare_retriever_results(baseline_scores, ft_scores)

            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print("\n📊 Retriever Validation Summary:")
            print(f"  Baseline Avg Relevance Score:   {baseline_scores['avg_relevant_score']:.4f}")
            print(f"  Fine-tuned Avg Relevance Score: {ft_scores['avg_relevant_score']:.4f}")
            print(f"  Improvement: {comparison['relevance_improvement']:.2f}%")

            if comparison['relevance_improvement'] > 0:
                print("  ✅ Fine-tuned model shows improvement!")
            else:
                print("  ❌ Fine-tuned model shows degradation!")

            results['summary'] = {
                'status': 'improved' if comparison['relevance_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['relevance_improvement']
            }

            return results

        except Exception as e:
            logger.error(f"Retriever validation failed: {e}")
            return None
    def validate_reranker(self, model_path, baseline_path="BAAI/bge-reranker-base"):
        """Quick validation of reranker model"""
        print("\n🏆 Validating Reranker Model")
        print("=" * 50)

        results = {
            'model_path': model_path,
            'baseline_path': baseline_path,
            'test_results': {},
            'summary': {}
        }

        try:
            # Test fine-tuned model
            print("Testing fine-tuned reranker...")
            ft_scores = self._test_reranker_model(model_path, "Fine-tuned")

            # Test baseline model
            print("Testing baseline reranker...")
            baseline_scores = self._test_reranker_model(baseline_path, "Baseline")

            # Compare results
            comparison = self._compare_reranker_results(baseline_scores, ft_scores)

            results['test_results'] = {
                'finetuned': ft_scores,
                'baseline': baseline_scores,
                'comparison': comparison
            }

            # Print summary
            print("\n📊 Reranker Validation Summary:")
            print(f"  Baseline Accuracy:   {baseline_scores['accuracy']:.4f}")
            print(f"  Fine-tuned Accuracy: {ft_scores['accuracy']:.4f}")
            print(f"  Improvement: {comparison['accuracy_improvement']:.2f}%")

            if comparison['accuracy_improvement'] > 0:
                print("  ✅ Fine-tuned reranker shows improvement!")
            else:
                print("  ❌ Fine-tuned reranker shows degradation!")

            results['summary'] = {
                'status': 'improved' if comparison['accuracy_improvement'] > 0 else 'degraded',
                'improvement_pct': comparison['accuracy_improvement']
            }

            return results

        except Exception as e:
            logger.error(f"Reranker validation failed: {e}")
            return None

    def _test_retriever_model(self, model_path, model_type):
        """Test retriever model with predefined queries"""
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGEM3Model.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            print(f"  ✅ {model_type} model loaded successfully")

            relevant_scores = []
            irrelevant_scores = []
            ranking_accuracy = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Encode query
                    query_inputs = tokenizer(
                        query,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=512
                    ).to(self.device)

                    query_outputs = model(**query_inputs, return_dense=True)
                    query_emb = query_outputs['dense']

                    # Test relevant documents
                    case_relevant_scores = []
                    for doc in relevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        doc_outputs = model(**doc_inputs, return_dense=True)
                        doc_emb = doc_outputs['dense']

                        # Compute similarity
                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        relevant_scores.append(similarity)
                        case_relevant_scores.append(similarity)

                    # Test irrelevant documents
                    case_irrelevant_scores = []
                    for doc in irrelevant_docs:
                        doc_inputs = tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        doc_outputs = model(**doc_inputs, return_dense=True)
                        doc_emb = doc_outputs['dense']

                        similarity = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        irrelevant_scores.append(similarity)
                        case_irrelevant_scores.append(similarity)

                    # Check ranking accuracy: on average, relevant documents should score
                    # higher than irrelevant ones (reuses the similarities computed above).
                    ranking_accuracy.append(
                        1.0 if np.mean(case_relevant_scores) > np.mean(case_irrelevant_scores) else 0.0
                    )
            return {
                'avg_relevant_score': np.mean(relevant_scores),
                'avg_irrelevant_score': np.mean(irrelevant_scores),
                'ranking_accuracy': np.mean(ranking_accuracy),
                'score_separation': np.mean(relevant_scores) - np.mean(irrelevant_scores)
            }

        except Exception as e:
            logger.error(f"Error testing {model_type} retriever: {e}")
            return None

    def _test_reranker_model(self, model_path, model_type):
        """Test reranker model with predefined query-document pairs"""
        try:
            # Load model
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGERerankerModel.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            print(f"  ✅ {model_type} reranker loaded successfully")

            correct_rankings = 0
            total_pairs = 0
            all_scores = []
            all_labels = []

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Test relevant documents (positive examples)
                    for doc in relevant_docs:
                        pair_text = f"{query} [SEP] {doc}"
                        inputs = tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        outputs = model(**inputs)
                        score = outputs.logits.item()

                        all_scores.append(score)
                        all_labels.append(1)  # Relevant

                        # Check if the score indicates relevance (positive score)
                        if score > 0:
                            correct_rankings += 1
                        total_pairs += 1

                    # Test irrelevant documents (negative examples)
                    for doc in irrelevant_docs:
                        pair_text = f"{query} [SEP] {doc}"
                        inputs = tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        outputs = model(**inputs)
                        score = outputs.logits.item()

                        all_scores.append(score)
                        all_labels.append(0)  # Irrelevant

                        # Check if the score indicates irrelevance (non-positive score)
                        if score <= 0:
                            correct_rankings += 1
                        total_pairs += 1

            return {
                'accuracy': correct_rankings / total_pairs if total_pairs > 0 else 0,
                'avg_positive_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 1]),
                'avg_negative_score': np.mean([s for s, l in zip(all_scores, all_labels) if l == 0]),
                'total_pairs': total_pairs
            }

        except Exception as e:
            logger.error(f"Error testing {model_type} reranker: {e}")
            return None

    def _compare_retriever_results(self, baseline, finetuned):
        """Compare retriever results (percentage change relative to the baseline)"""
        if not baseline or not finetuned:
            return {}

        relevance_improvement = (
            (finetuned['avg_relevant_score'] - baseline['avg_relevant_score'])
            / baseline['avg_relevant_score'] * 100
        ) if baseline['avg_relevant_score'] != 0 else 0

        separation_improvement = (
            (finetuned['score_separation'] - baseline['score_separation'])
            / baseline['score_separation'] * 100
        ) if baseline['score_separation'] != 0 else 0

        ranking_improvement = (
            (finetuned['ranking_accuracy'] - baseline['ranking_accuracy'])
            / baseline['ranking_accuracy'] * 100
        ) if baseline['ranking_accuracy'] != 0 else 0

        return {
            'relevance_improvement': relevance_improvement,
            'separation_improvement': separation_improvement,
            'ranking_improvement': ranking_improvement
        }

    def _compare_reranker_results(self, baseline, finetuned):
        """Compare reranker results (percentage change relative to the baseline)"""
        if not baseline or not finetuned:
            return {}

        accuracy_improvement = (
            (finetuned['accuracy'] - baseline['accuracy'])
            / baseline['accuracy'] * 100
        ) if baseline['accuracy'] != 0 else 0

        return {
            'accuracy_improvement': accuracy_improvement
        }

    def validate_pipeline(self, retriever_path, reranker_path):
        """Quick validation of the full retrieval + reranking pipeline"""
        print("\n🚀 Validating Full Pipeline")
        print("=" * 50)

        try:
            # Load models
            retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
            retriever = BGEM3Model.from_pretrained(retriever_path)
            retriever.to(self.device)
            retriever.eval()

            reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
            reranker = BGERerankerModel.from_pretrained(reranker_path)
            reranker.to(self.device)
            reranker.eval()

            print("  ✅ Pipeline models loaded successfully")

            total_queries = 0
            pipeline_accuracy = 0

            with torch.no_grad():
                for test_case in self.test_queries:
                    query = test_case["query"]
                    relevant_docs = test_case["relevant"]
                    irrelevant_docs = test_case["irrelevant"]

                    # Create candidate pool
                    all_docs = relevant_docs + irrelevant_docs
                    doc_labels = [1] * len(relevant_docs) + [0] * len(irrelevant_docs)

                    # Step 1: Retrieval
                    query_inputs = retriever_tokenizer(
                        query,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=512
                    ).to(self.device)
                    query_emb = retriever(**query_inputs, return_dense=True)['dense']

                    # Score all documents with the retriever
                    doc_scores = []
                    for doc in all_docs:
                        doc_inputs = retriever_tokenizer(
                            doc,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)
                        doc_emb = retriever(**doc_inputs, return_dense=True)['dense']
                        score = torch.cosine_similarity(query_emb, doc_emb, dim=1).item()
                        doc_scores.append(score)

                    # Get top-k for reranking (all docs in this small candidate pool)
                    retrieval_indices = sorted(range(len(doc_scores)),
                                               key=lambda i: doc_scores[i],
                                               reverse=True)

                    # Step 2: Reranking
                    rerank_scores = []
                    rerank_labels = []
                    for idx in retrieval_indices:
                        doc = all_docs[idx]
                        label = doc_labels[idx]

                        pair_text = f"{query} [SEP] {doc}"
                        pair_inputs = reranker_tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=512
                        ).to(self.device)

                        rerank_output = reranker(**pair_inputs)
                        rerank_score = rerank_output.logits.item()

                        rerank_scores.append(rerank_score)
                        rerank_labels.append(label)

                    # Check if the top-ranked document is relevant
                    best_idx = np.argmax(rerank_scores)
                    if rerank_labels[best_idx] == 1:
                        pipeline_accuracy += 1
                    total_queries += 1

            pipeline_acc = pipeline_accuracy / total_queries if total_queries > 0 else 0

            print("\n📊 Pipeline Validation Summary:")
            print(f"  Pipeline Accuracy: {pipeline_acc:.4f}")
            print(f"  Queries Tested: {total_queries}")

            if pipeline_acc > 0.5:
                print("  ✅ Pipeline performing well!")
            else:
                print("  ❌ Pipeline needs improvement!")

            return {
                'accuracy': pipeline_acc,
                'total_queries': total_queries,
                'status': 'good' if pipeline_acc > 0.5 else 'needs_improvement'
            }

        except Exception as e:
            logger.error(f"Pipeline validation failed: {e}")
            return None


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Quick validation of BGE fine-tuned models",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("--retriever_model", type=str, default=None,
                        help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, default=None,
                        help="Path to fine-tuned reranker model")
    parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
                        help="Path to baseline retriever model")
    parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
                        help="Path to baseline reranker model")
    parser.add_argument("--test_pipeline", action="store_true",
                        help="Test full retrieval + reranking pipeline")
    parser.add_argument("--output_file", type=str, default=None,
                        help="Save results to JSON file")

    return parser.parse_args()


def main():
    """Main validation function"""
    args = parse_args()

    if not args.retriever_model and not args.reranker_model:
        print("❌ Error: At least one model must be specified")
        print("   Use --retriever_model or --reranker_model")
        return 1

    print("🚀 BGE Quick Validation")
    print("=" * 60)
    print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    validator = QuickValidator()
    results = {}

    # Validate retriever
    if args.retriever_model:
        if os.path.exists(args.retriever_model):
            retriever_results = validator.validate_retriever(
                args.retriever_model,
                args.retriever_baseline
            )
            if retriever_results:
                results['retriever'] = retriever_results
        else:
            print(f"❌ Retriever model not found: {args.retriever_model}")

    # Validate reranker
    if args.reranker_model:
        if os.path.exists(args.reranker_model):
            reranker_results = validator.validate_reranker(
                args.reranker_model,
                args.reranker_baseline
            )
            if reranker_results:
                results['reranker'] = reranker_results
        else:
            print(f"❌ Reranker model not found: {args.reranker_model}")

    # Validate pipeline
    if (args.test_pipeline and args.retriever_model and args.reranker_model
            and os.path.exists(args.retriever_model) and os.path.exists(args.reranker_model)):
        pipeline_results = validator.validate_pipeline(
            args.retriever_model,
            args.reranker_model
        )
        if pipeline_results:
            results['pipeline'] = pipeline_results

    # Print final summary
    print(f"\n{'=' * 60}")
    print("🎯 FINAL VALIDATION SUMMARY")
    print(f"{'=' * 60}")

    overall_status = "✅ SUCCESS"
    issues = []

    for model_type, model_results in results.items():
        if 'summary' in model_results:
            status = model_results['summary'].get('status', 'unknown')
            improvement = model_results['summary'].get('improvement_pct', 0)

            if status == 'improved':
                print(f"✅ {model_type.title()}: Improved by {improvement:.2f}%")
            elif status == 'degraded':
                print(f"❌ {model_type.title()}: Degraded by {abs(improvement):.2f}%")
                issues.append(model_type)
                overall_status = "⚠️ ISSUES DETECTED"
        elif model_type == 'pipeline' and 'status' in model_results:
            if model_results['status'] == 'good':
                print(f"✅ Pipeline: Working well (accuracy: {model_results['accuracy']:.3f})")
            else:
                print(f"❌ Pipeline: Needs improvement (accuracy: {model_results['accuracy']:.3f})")
                issues.append('pipeline')
                overall_status = "⚠️ ISSUES DETECTED"

    print(f"\n{overall_status}")

    if issues:
        print(f"⚠️ Issues detected in: {', '.join(issues)}")
        print("💡 Consider running comprehensive validation for detailed analysis")

    # Save results if requested
    if args.output_file:
        results['timestamp'] = datetime.now().isoformat()
        results['overall_status'] = overall_status
        results['issues'] = issues

        with open(args.output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"📊 Results saved to: {args.output_file}")

    return 0 if overall_status == "✅ SUCCESS" else 1


if __name__ == "__main__":
    sys.exit(main())