#!/usr/bin/env python3
"""
Comprehensive Validation Script for BGE Fine-tuned Models

This script provides a complete validation framework to determine whether
fine-tuned models actually perform better than baseline models across
multiple datasets and metrics.

Features:
- Tests both M3 retriever and reranker models
- Compares against baseline models on multiple datasets
- Provides statistical significance testing
- Tests pipeline performance (retrieval + reranking)
- Generates comprehensive reports
- Includes performance degradation detection
- Supports cross-dataset validation

Usage:
    python scripts/comprehensive_validation.py \\
        --retriever_finetuned ./output/bge-m3-enhanced/final_model \\
        --reranker_finetuned ./output/bge-reranker/final_model \\
        --output_dir ./validation_results
"""

import os
import sys
import json
import logging
import argparse
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass

import pandas as pd
import numpy as np
from scipy import stats
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import (
    compute_retrieval_metrics,
    compute_reranker_metrics,
    compute_retrieval_metrics_comprehensive
)
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
from utils.config_loader import get_config


@dataclass
class ValidationConfig:
    """Configuration for comprehensive validation"""

    # Model paths
    retriever_baseline: str = "BAAI/bge-m3"
    reranker_baseline: str = "BAAI/bge-reranker-base"
    retriever_finetuned: Optional[str] = None
    reranker_finetuned: Optional[str] = None

    # Test datasets
    test_datasets: Optional[List[str]] = None

    # Evaluation settings
    batch_size: int = 32
    max_length: int = 512
    k_values: Optional[List[int]] = None
    significance_level: float = 0.05
    min_samples_for_significance: int = 30

    # Pipeline settings
    retrieval_top_k: int = 100
    rerank_top_k: int = 10

    # Output settings
    output_dir: str = "./validation_results"
    save_embeddings: bool = False
    create_report: bool = True

    def __post_init__(self):
        if self.test_datasets is None:
            self.test_datasets = [
                "data/datasets/examples/embedding_data.jsonl",
                "data/datasets/examples/reranker_data.jsonl",
                "data/datasets/Reranker_AFQMC/dev.json",
                "data/datasets/三国演义/reranker_train.jsonl"
            ]
        if self.k_values is None:
            self.k_values = [1, 3, 5, 10, 20]
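
# Example usage (paths taken from the Usage section above; purely illustrative):
#
#     config = ValidationConfig(
#         retriever_finetuned="./output/bge-m3-enhanced/final_model",
#         reranker_finetuned="./output/bge-reranker/final_model",
#         output_dir="./validation_results",
#     )
#     validator = ComprehensiveValidator(config)
#     results = validator.validate_all()
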
class ComprehensiveValidator:
    """Comprehensive validation framework for BGE models"""

    def __init__(self, config: ValidationConfig):
        self.config = config
        self.logger = get_logger(__name__)
        self.results = {}
        self.device = self._get_device()

        # Create output directory
        os.makedirs(config.output_dir, exist_ok=True)

        # Setup logging
        setup_logging(
            output_dir=config.output_dir,
            log_level=logging.INFO,
            log_to_console=True,
            log_to_file=True
        )

    def _get_device(self) -> torch.device:
        """Get optimal device for evaluation"""
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')

            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401
                    device = torch.device("npu:0")
                    self.logger.info("Using NPU device")
                    return device
                except ImportError:
                    pass

            if torch.cuda.is_available():
                device = torch.device("cuda")
                self.logger.info("Using CUDA device")
                return device
        except Exception as e:
            self.logger.warning(f"Error detecting optimal device: {e}")

        device = torch.device("cpu")
        self.logger.info("Using CPU device")
        return device

    def validate_all(self) -> Dict[str, Any]:
        """Run comprehensive validation on all models and datasets"""
        self.logger.info("Starting Comprehensive BGE Model Validation")
        self.logger.info("=" * 80)
        start_time = datetime.now()

        try:
            # 1. Validate retriever models
            if self.config.retriever_finetuned:
                self.logger.info("Validating Retriever Models...")
                retriever_results = self._validate_retrievers()
                self.results['retriever'] = retriever_results

            # 2. Validate reranker models
            if self.config.reranker_finetuned:
                self.logger.info("Validating Reranker Models...")
                reranker_results = self._validate_rerankers()
                self.results['reranker'] = reranker_results

            # 3. Validate pipeline (if both models available)
            if self.config.retriever_finetuned and self.config.reranker_finetuned:
                self.logger.info("Validating Pipeline Performance...")
                pipeline_results = self._validate_pipeline()
                self.results['pipeline'] = pipeline_results

            # 4. Statistical analysis (run before reporting so it is included in the outputs)
            self._perform_statistical_analysis()

            # 5. Generate comprehensive report
            if self.config.create_report:
                self._generate_report()

            end_time = datetime.now()
            duration = (end_time - start_time).total_seconds()

            self.logger.info("Comprehensive validation completed!")
            self.logger.info(f"Total validation time: {duration:.2f} seconds")

            return self.results

        except Exception as e:
            self.logger.error(f"Validation failed: {e}", exc_info=True)
            raise

    def _validate_retrievers(self) -> Dict[str, Any]:
        """Validate retriever models on all suitable datasets"""
        results = {
            'baseline_metrics': {},
            'finetuned_metrics': {},
            'comparisons': {},
            'datasets_tested': []
        }

        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                self.logger.warning(f"Dataset not found: {dataset_path}")
                continue

            # Skip reranker-only datasets for retriever testing
            if 'reranker' in dataset_path.lower() and 'AFQMC' in dataset_path:
                continue

            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing retriever on dataset: {dataset_name}")

            try:
                # Test baseline model
                baseline_metrics = self._test_retriever_on_dataset(
                    model_path=self.config.retriever_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )

                # Test fine-tuned model
                finetuned_metrics = self._test_retriever_on_dataset(
                    model_path=self.config.retriever_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )

                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "retriever")

                results['baseline_metrics'][dataset_name] = baseline_metrics
                results['finetuned_metrics'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)

                self.logger.info(f"Retriever validation complete for {dataset_name}")

            except Exception as e:
                self.logger.error(f"Failed to test retriever on {dataset_name}: {e}")
                continue

        return results
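
    # Each per-dataset `comparison` stored above comes from _compare_metrics
    # (defined later in this class). Sketch of its expected shape, with
    # illustrative values only:
    #
    #     {
    #         'improvements': {'recall@10': {'baseline': 0.71, 'finetuned': 0.78,
    #                                        'improvement_pct': 9.86,
    #                                        'absolute_improvement': 0.07}},
    #         'degradations': {},
    #         'statistical_significance': {},
    #         'summary': {'total_metrics': 10, 'improvements': 8, 'degradations': 1,
    #                     'unchanged': 1, 'improvement_ratio': 0.8,
    #                     'net_improvement': 7},
    #     }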
    def _validate_rerankers(self) -> Dict[str, Any]:
        """Validate reranker models on all suitable datasets"""
        results = {
            'baseline_metrics': {},
            'finetuned_metrics': {},
            'comparisons': {},
            'datasets_tested': []
        }

        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                self.logger.warning(f"Dataset not found: {dataset_path}")
                continue

            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing reranker on dataset: {dataset_name}")

            try:
                # Test baseline model
                baseline_metrics = self._test_reranker_on_dataset(
                    model_path=self.config.reranker_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )

                # Test fine-tuned model
                finetuned_metrics = self._test_reranker_on_dataset(
                    model_path=self.config.reranker_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )

                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "reranker")

                results['baseline_metrics'][dataset_name] = baseline_metrics
                results['finetuned_metrics'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)

                self.logger.info(f"Reranker validation complete for {dataset_name}")

            except Exception as e:
                self.logger.error(f"Failed to test reranker on {dataset_name}: {e}")
                continue

        return results

    def _validate_pipeline(self) -> Dict[str, Any]:
        """Validate end-to-end retrieval + reranking pipeline"""
        results = {
            'baseline_pipeline': {},
            'finetuned_pipeline': {},
            'comparisons': {},
            'datasets_tested': []
        }

        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                continue

            # Only test on datasets suitable for pipeline evaluation
            if 'embedding_data' not in dataset_path:
                continue

            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing pipeline on dataset: {dataset_name}")

            try:
                # Test baseline pipeline
                baseline_metrics = self._test_pipeline_on_dataset(
                    retriever_path=self.config.retriever_baseline,
                    reranker_path=self.config.reranker_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )

                # Test fine-tuned pipeline
                finetuned_metrics = self._test_pipeline_on_dataset(
                    retriever_path=self.config.retriever_finetuned,
                    reranker_path=self.config.reranker_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )

                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "pipeline")

                results['baseline_pipeline'][dataset_name] = baseline_metrics
                results['finetuned_pipeline'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)

                self.logger.info(f"Pipeline validation complete for {dataset_name}")

            except Exception as e:
                self.logger.error(f"Failed to test pipeline on {dataset_name}: {e}")
                continue

        return results

    def _test_retriever_on_dataset(self, model_path: str, dataset_path: str,
                                   is_baseline: bool) -> Dict[str, float]:
        """Test a retriever model on a specific dataset"""
        try:
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGEM3Model.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            # Create dataset
            dataset = BGEM3Dataset(
                data_path=dataset_path,
                tokenizer=tokenizer,
                query_max_length=self.config.max_length,
                passage_max_length=self.config.max_length,
                is_train=False
            )

            dataloader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=False
            )

            # Create evaluator
            evaluator = RetrievalEvaluator(
                model=model,
                tokenizer=tokenizer,
                device=self.device,
                k_values=self.config.k_values
            )

            # Run evaluation
            metrics = evaluator.evaluate(
                dataloader=dataloader,
                desc=f"{'Baseline' if is_baseline else 'Fine-tuned'} Retriever"
            )

            return metrics

        except Exception as e:
            self.logger.error(f"Error testing retriever: {e}")
            return {}
    def _test_reranker_on_dataset(self, model_path: str, dataset_path: str,
                                  is_baseline: bool) -> Dict[str, float]:
        """Test a reranker model on a specific dataset"""
        try:
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGERerankerModel.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            # Create dataset
            dataset = BGERerankerDataset(
                data_path=dataset_path,
                tokenizer=tokenizer,
                query_max_length=self.config.max_length,
                passage_max_length=self.config.max_length,
                is_train=False
            )

            dataloader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=False
            )

            # Collect predictions and labels
            all_scores = []
            all_labels = []

            with torch.no_grad():
                for batch in dataloader:
                    # Move to device
                    input_ids = batch['input_ids'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)
                    labels = batch['labels'].cpu().numpy()

                    # Get predictions
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    scores = outputs.logits.cpu().numpy()

                    all_scores.append(scores)
                    all_labels.append(labels)

            # Concatenate results
            all_scores = np.concatenate(all_scores, axis=0)
            all_labels = np.concatenate(all_labels, axis=0)

            # Compute metrics
            metrics = compute_reranker_metrics(
                scores=all_scores,
                labels=all_labels,
                k_values=self.config.k_values
            )

            return metrics

        except Exception as e:
            self.logger.error(f"Error testing reranker: {e}")
            return {}
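
    # _test_pipeline_on_dataset (below) expects a JSONL file in which every line
    # carries 'query', 'pos', and 'neg' fields; lines missing any of these are
    # skipped. Illustrative record (contents are hypothetical):
    #
    #     {"query": "...", "pos": ["relevant passage", "..."],
    #      "neg": ["irrelevant passage", "..."]}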
    def _test_pipeline_on_dataset(self, retriever_path: str, reranker_path: str,
                                  dataset_path: str, is_baseline: bool) -> Dict[str, float]:
        """Test end-to-end pipeline on a dataset"""
        try:
            # Load models
            retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
            retriever = BGEM3Model.from_pretrained(retriever_path)
            retriever.to(self.device)
            retriever.eval()

            reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
            reranker = BGERerankerModel.from_pretrained(reranker_path)
            reranker.to(self.device)
            reranker.eval()

            # Load and preprocess data
            data = []
            with open(dataset_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        item = json.loads(line.strip())
                        if 'query' in item and 'pos' in item and 'neg' in item:
                            data.append(item)
                    except json.JSONDecodeError:
                        continue

            if not data:
                return {}

            # Test pipeline performance
            total_queries = 0
            pipeline_metrics = {f'recall@{k}': [] for k in self.config.k_values}
            pipeline_metrics.update({f'mrr@{k}': [] for k in self.config.k_values})

            for item in data[:100]:  # Limit to 100 queries for speed
                query = item['query']
                pos_docs = item['pos'][:2]  # Take first 2 positive docs
                neg_docs = item['neg'][:8]  # Take first 8 negative docs

                # Create candidate pool
                candidates = pos_docs + neg_docs
                labels = [1] * len(pos_docs) + [0] * len(neg_docs)

                if len(candidates) < 3:  # Need minimum candidates
                    continue

                try:
                    # Step 1: Retrieval (encode query and candidates)
                    query_inputs = retriever_tokenizer(
                        query,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=self.config.max_length
                    ).to(self.device)

                    with torch.no_grad():
                        query_outputs = retriever(**query_inputs, return_dense=True)
                        query_emb = query_outputs['dense']

                    # Encode candidates
                    candidate_embs = []
                    for candidate in candidates:
                        cand_inputs = retriever_tokenizer(
                            candidate,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=self.config.max_length
                        ).to(self.device)

                        with torch.no_grad():
                            cand_outputs = retriever(**cand_inputs, return_dense=True)
                            candidate_embs.append(cand_outputs['dense'])

                    candidate_embs = torch.cat(candidate_embs, dim=0)

                    # Compute similarities
                    similarities = torch.cosine_similarity(query_emb, candidate_embs, dim=1)

                    # Get top-k for retrieval
                    top_k_retrieval = min(self.config.retrieval_top_k, len(candidates))
                    retrieval_indices = torch.argsort(similarities, descending=True)[:top_k_retrieval]
                    retrieval_indices = retrieval_indices.cpu().tolist()

                    # Step 2: Reranking
                    rerank_candidates = [candidates[i] for i in retrieval_indices]
                    rerank_labels = [labels[i] for i in retrieval_indices]

                    # Score pairs with reranker
                    rerank_scores = []
                    for candidate in rerank_candidates:
                        pair_text = f"{query} [SEP] {candidate}"
                        pair_inputs = reranker_tokenizer(
                            pair_text,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=self.config.max_length
                        ).to(self.device)

                        with torch.no_grad():
                            pair_outputs = reranker(**pair_inputs)
                            score = pair_outputs.logits.item()
                            rerank_scores.append(score)

                    # Final ranking
                    final_indices = np.argsort(rerank_scores)[::-1]
                    final_labels = [rerank_labels[i] for i in final_indices]

                    # Compute metrics
                    for k in self.config.k_values:
                        if k <= len(final_labels):
                            # Recall@k
                            relevant_in_topk = sum(final_labels[:k])
                            total_relevant = sum(rerank_labels)
                            recall = relevant_in_topk / total_relevant if total_relevant > 0 else 0
                            pipeline_metrics[f'recall@{k}'].append(recall)

                            # MRR@k (for/else: append 0.0 when no relevant doc is in the top k)
                            for i, label in enumerate(final_labels[:k]):
                                if label == 1:
                                    pipeline_metrics[f'mrr@{k}'].append(1.0 / (i + 1))
                                    break
                            else:
                                pipeline_metrics[f'mrr@{k}'].append(0.0)

                    total_queries += 1

                except Exception as e:
                    self.logger.warning(f"Error processing query: {e}")
                    continue

            # Average metrics
            final_metrics = {}
            for metric, values in pipeline_metrics.items():
                if values:
                    final_metrics[metric] = np.mean(values)
                else:
                    final_metrics[metric] = 0.0

            final_metrics['total_queries'] = total_queries

            return final_metrics

        except Exception as e:
            self.logger.error(f"Error testing pipeline: {e}")
            return {}

    def _compare_metrics(self, baseline: Dict[str, float], finetuned: Dict[str, float],
                         model_type: str) -> Dict[str, Any]:
        """Compare baseline vs fine-tuned metrics with statistical analysis"""
        comparison = {
            'improvements': {},
            'degradations': {},
            'statistical_significance': {},
            'summary': {}
        }

        # Compare common metrics
        common_metrics = set(baseline.keys()) & set(finetuned.keys())
        improvements = 0
        degradations = 0

        for metric in common_metrics:
            if isinstance(baseline[metric], (int, float)) and isinstance(finetuned[metric], (int, float)):
                baseline_val = baseline[metric]
                finetuned_val = finetuned[metric]

                # Calculate improvement
                if baseline_val != 0:
                    improvement_pct = ((finetuned_val - baseline_val) / baseline_val) * 100
                else:
                    improvement_pct = float('inf') if finetuned_val > 0 else 0

                # Classify improvement/degradation
                if finetuned_val > baseline_val:
                    comparison['improvements'][metric] = {
                        'baseline': baseline_val,
                        'finetuned': finetuned_val,
                        'improvement_pct': improvement_pct,
                        'absolute_improvement': finetuned_val - baseline_val
                    }
                    improvements += 1
                elif finetuned_val < baseline_val:
                    comparison['degradations'][metric] = {
                        'baseline': baseline_val,
                        'finetuned': finetuned_val,
                        'degradation_pct': -improvement_pct,
                        'absolute_degradation': baseline_val - finetuned_val
                    }
                    degradations += 1

        # Summary statistics
        comparison['summary'] = {
            'total_metrics': len(common_metrics),
            'improvements': improvements,
            'degradations': degradations,
            'unchanged': len(common_metrics) - improvements - degradations,
            'improvement_ratio': improvements / len(common_metrics) if common_metrics else 0,
            'net_improvement': improvements - degradations
        }

        return comparison
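
    # Illustrative sketch only (not wired into validate_all): if per-query metric
    # values were collected for both models, a paired t-test via scipy.stats could
    # back the significance_level / min_samples_for_significance settings. The
    # helper below is an assumption about how that might look, not part of the
    # original pipeline.
    def _paired_significance_test(self, baseline_scores: List[float],
                                  finetuned_scores: List[float]) -> Dict[str, Any]:
        """Paired t-test on per-query scores (illustrative helper)."""
        if (len(baseline_scores) != len(finetuned_scores)
                or len(baseline_scores) < self.config.min_samples_for_significance):
            return {'significant': False, 'reason': 'insufficient paired samples'}

        t_stat, p_value = stats.ttest_rel(finetuned_scores, baseline_scores)
        return {
            't_statistic': float(t_stat),
            'p_value': float(p_value),
            'significant': bool(p_value < self.config.significance_level and t_stat > 0),
        }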
    def _perform_statistical_analysis(self):
        """Perform statistical analysis on validation results"""
        self.logger.info("Performing Statistical Analysis...")

        analysis = {
            'overall_summary': {},
            'significance_tests': {},
            'confidence_intervals': {},
            'effect_sizes': {}
        }

        # Analyze each model type
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type not in self.results:
                continue

            model_results = self.results[model_type]

            # Collect all improvements across datasets
            all_improvements = []
            significant_improvements = 0

            for dataset, comparison in model_results.get('comparisons', {}).items():
                for metric, improvement_data in comparison.get('improvements', {}).items():
                    improvement_pct = improvement_data.get('improvement_pct', 0)
                    if improvement_pct > 0:
                        all_improvements.append(improvement_pct)

                        # Simple significance test (can be enhanced)
                        if improvement_pct > 5.0:  # 5% improvement threshold
                            significant_improvements += 1

            if all_improvements:
                analysis[model_type] = {
                    'mean_improvement_pct': np.mean(all_improvements),
                    'median_improvement_pct': np.median(all_improvements),
                    'std_improvement_pct': np.std(all_improvements),
                    'min_improvement_pct': np.min(all_improvements),
                    'max_improvement_pct': np.max(all_improvements),
                    'significant_improvements': significant_improvements,
                    'total_improvements': len(all_improvements)
                }

        self.results['statistical_analysis'] = analysis

    def _generate_report(self):
        """Generate comprehensive validation report"""
        self.logger.info("Generating Comprehensive Report...")

        report_path = os.path.join(self.config.output_dir, "validation_report.html")
        html_content = self._create_html_report()

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        # Also save JSON results
        json_path = os.path.join(self.config.output_dir, "validation_results.json")
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, default=str)

        self.logger.info(f"Report saved to: {report_path}")
        self.logger.info(f"JSON results saved to: {json_path}")
    def _create_html_report(self) -> str:
        """Create HTML report with comprehensive validation results"""
        html = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>BGE Model Validation Report</title>
</head>
<body>
    <h1>BGE Model Validation Report</h1>
    <p>Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
    <h2>Configuration:</h2>
    <ul>
        <li>Retriever baseline: {self.config.retriever_baseline}</li>
        <li>Retriever fine-tuned: {self.config.retriever_finetuned}</li>
        <li>Reranker baseline: {self.config.reranker_baseline}</li>
        <li>Reranker fine-tuned: {self.config.reranker_finetuned}</li>
        <li>Output directory: {self.config.output_dir}</li>
    </ul>
"""

        # One comparison table per model type and dataset
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type not in self.results:
                continue

            html += f"    <h2>{model_type.capitalize()} Results</h2>\n"
            for dataset, comparison in self.results[model_type].get('comparisons', {}).items():
                html += f"    <h3>Dataset: {dataset}</h3>\n"
                html += (
                    "    <table border='1'>\n"
                    "        <tr><th>Metric</th><th>Baseline</th><th>Fine-tuned</th>"
                    "<th>Change</th><th>Status</th></tr>\n"
                )

                for metric, data in comparison.get('improvements', {}).items():
                    html += (
                        f"        <tr><td>{metric}</td><td>{data['baseline']:.4f}</td>"
                        f"<td>{data['finetuned']:.4f}</td>"
                        f"<td>+{data['improvement_pct']:.2f}%</td>"
                        f"<td>Improved</td></tr>\n"
                    )

                for metric, data in comparison.get('degradations', {}).items():
                    html += (
                        f"        <tr><td>{metric}</td><td>{data['baseline']:.4f}</td>"
                        f"<td>{data['finetuned']:.4f}</td>"
                        f"<td>-{data['degradation_pct']:.2f}%</td>"
                        f"<td>Degraded</td></tr>\n"
                    )

                html += "    </table>\n"

        html += "</body>\n</html>\n"
        return html
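
# A command-line entry point is implied by the module docstring but not shown in
# this section. Minimal sketch (flag names taken from the Usage example above;
# the wiring is an assumption, not the original implementation):
#
#     def main():
#         parser = argparse.ArgumentParser(description="Comprehensive BGE validation")
#         parser.add_argument("--retriever_finetuned", type=str, default=None)
#         parser.add_argument("--reranker_finetuned", type=str, default=None)
#         parser.add_argument("--output_dir", type=str, default="./validation_results")
#         args = parser.parse_args()
#
#         config = ValidationConfig(
#             retriever_finetuned=args.retriever_finetuned,
#             reranker_finetuned=args.reranker_finetuned,
#             output_dir=args.output_dir,
#         )
#         ComprehensiveValidator(config).validate_all()
#
#     if __name__ == "__main__":
#         main()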