#!/usr/bin/env python3
"""
Comprehensive Validation Script for BGE Fine-tuned Models

This script provides a complete validation framework to determine if fine-tuned models
actually perform better than baseline models across multiple datasets and metrics.

Features:
- Tests both M3 retriever and reranker models
- Compares against baseline models on multiple datasets
- Provides statistical significance testing
- Tests pipeline performance (retrieval + reranking)
- Generates comprehensive reports
- Includes performance degradation detection
- Supports cross-dataset validation

Usage:
    python scripts/comprehensive_validation.py \\
        --retriever_finetuned ./output/bge-m3-enhanced/final_model \\
        --reranker_finetuned ./output/bge-reranker/final_model \\
        --output_dir ./validation_results
"""

import os
import sys
import json
import logging
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from scipy import stats
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from data.dataset import BGEM3Dataset, BGERerankerDataset
from evaluation.metrics import (
    compute_retrieval_metrics, compute_reranker_metrics,
    compute_retrieval_metrics_comprehensive
)
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
from utils.config_loader import get_config


@dataclass
class ValidationConfig:
    """Configuration for comprehensive validation"""
    # Model paths
    retriever_baseline: str = "BAAI/bge-m3"
    reranker_baseline: str = "BAAI/bge-reranker-base"
    retriever_finetuned: Optional[str] = None
    reranker_finetuned: Optional[str] = None

    # Test datasets
    test_datasets: Optional[List[str]] = None

    # Evaluation settings
    batch_size: int = 32
    max_length: int = 512
    k_values: Optional[List[int]] = None
    significance_level: float = 0.05
    min_samples_for_significance: int = 30

    # Pipeline settings
    retrieval_top_k: int = 100
    rerank_top_k: int = 10

    # Output settings
    output_dir: str = "./validation_results"
    save_embeddings: bool = False
    create_report: bool = True

    def __post_init__(self):
        if self.test_datasets is None:
            self.test_datasets = [
                "data/datasets/examples/embedding_data.jsonl",
                "data/datasets/examples/reranker_data.jsonl",
                "data/datasets/Reranker_AFQMC/dev.json",
                "data/datasets/三国演义/reranker_train.jsonl"
            ]
        if self.k_values is None:
            self.k_values = [1, 3, 5, 10, 20]
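
# Example (hypothetical paths), sketching programmatic use of the validator without the CLI:
#
#     config = ValidationConfig(
#         retriever_finetuned="./output/bge-m3-enhanced/final_model",
#         test_datasets=["data/datasets/examples/embedding_data.jsonl"],
#         k_values=[1, 5, 10],
#     )
#     results = ComprehensiveValidator(config).validate_all()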


class ComprehensiveValidator:
    """Comprehensive validation framework for BGE models"""

    def __init__(self, config: ValidationConfig):
        self.config = config
        self.logger = get_logger(__name__)
        self.results = {}
        self.device = self._get_device()

        # Create output directory
        os.makedirs(config.output_dir, exist_ok=True)

        # Setup logging
        setup_logging(
            output_dir=config.output_dir,
            log_level=logging.INFO,
            log_to_console=True,
            log_to_file=True
        )

    def _get_device(self) -> torch.device:
        """Get optimal device for evaluation"""
        try:
            config = get_config()
            device_type = config.get('hardware', {}).get('device', 'cuda')

            # Prefer NPU when requested; fall through to CUDA/CPU if torch_npu is unavailable.
            if device_type == 'npu':
                try:
                    import torch_npu  # noqa: F401  (registers the NPU backend)
                    device = torch.device("npu:0")
                    self.logger.info("Using NPU device")
                    return device
                except ImportError:
                    pass

            if torch.cuda.is_available():
                device = torch.device("cuda")
                self.logger.info("Using CUDA device")
                return device

        except Exception as e:
            self.logger.warning(f"Error detecting optimal device: {e}")

        device = torch.device("cpu")
        self.logger.info("Using CPU device")
        return device
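
    # Layout of self.results as populated by validate_all() (keys appear only for the
    # steps that actually run):
    #     {
    #         'retriever': {...},            # from _validate_retrievers()
    #         'reranker': {...},             # from _validate_rerankers()
    #         'pipeline': {...},             # from _validate_pipeline()
    #         'statistical_analysis': {...}  # from _perform_statistical_analysis()
    #     }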

    def validate_all(self) -> Dict[str, Any]:
        """Run comprehensive validation on all models and datasets"""
        self.logger.info("🚀 Starting Comprehensive BGE Model Validation")
        self.logger.info("=" * 80)

        start_time = datetime.now()

        try:
            # 1. Validate retriever models
            if self.config.retriever_finetuned:
                self.logger.info("📊 Validating Retriever Models...")
                retriever_results = self._validate_retrievers()
                self.results['retriever'] = retriever_results

            # 2. Validate reranker models
            if self.config.reranker_finetuned:
                self.logger.info("📊 Validating Reranker Models...")
                reranker_results = self._validate_rerankers()
                self.results['reranker'] = reranker_results

            # 3. Validate pipeline (if both models available)
            if self.config.retriever_finetuned and self.config.reranker_finetuned:
                self.logger.info("📊 Validating Pipeline Performance...")
                pipeline_results = self._validate_pipeline()
                self.results['pipeline'] = pipeline_results

            # 4. Statistical analysis (run before the report so it is included in the outputs)
            self._perform_statistical_analysis()

            # 5. Generate comprehensive report
            if self.config.create_report:
                self._generate_report()

            end_time = datetime.now()
            duration = (end_time - start_time).total_seconds()

            self.logger.info("✅ Comprehensive validation completed!")
            self.logger.info(f"⏱️ Total validation time: {duration:.2f} seconds")

            return self.results

        except Exception as e:
            self.logger.error(f"❌ Validation failed: {e}", exc_info=True)
            raise

    def _validate_retrievers(self) -> Dict[str, Any]:
        """Validate retriever models on all suitable datasets"""
        results = {
            'baseline_metrics': {},
            'finetuned_metrics': {},
            'comparisons': {},
            'datasets_tested': []
        }

        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                self.logger.warning(f"Dataset not found: {dataset_path}")
                continue

            # Skip the AFQMC reranker dataset for retriever testing
            if 'reranker' in dataset_path.lower() and 'AFQMC' in dataset_path:
                continue

            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing retriever on dataset: {dataset_name}")

            try:
                # Test baseline model
                baseline_metrics = self._test_retriever_on_dataset(
                    model_path=self.config.retriever_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )

                # Test fine-tuned model
                finetuned_metrics = self._test_retriever_on_dataset(
                    model_path=self.config.retriever_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )

                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "retriever")

                results['baseline_metrics'][dataset_name] = baseline_metrics
                results['finetuned_metrics'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)

                self.logger.info(f"✅ Retriever validation complete for {dataset_name}")

            except Exception as e:
                self.logger.error(f"❌ Failed to test retriever on {dataset_name}: {e}")
                continue

        return results

    def _validate_rerankers(self) -> Dict[str, Any]:
        """Validate reranker models on all suitable datasets"""
        results = {
            'baseline_metrics': {},
            'finetuned_metrics': {},
            'comparisons': {},
            'datasets_tested': []
        }

        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                self.logger.warning(f"Dataset not found: {dataset_path}")
                continue

            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing reranker on dataset: {dataset_name}")

            try:
                # Test baseline model
                baseline_metrics = self._test_reranker_on_dataset(
                    model_path=self.config.reranker_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )

                # Test fine-tuned model
                finetuned_metrics = self._test_reranker_on_dataset(
                    model_path=self.config.reranker_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )

                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "reranker")

                results['baseline_metrics'][dataset_name] = baseline_metrics
                results['finetuned_metrics'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)

                self.logger.info(f"✅ Reranker validation complete for {dataset_name}")

            except Exception as e:
                self.logger.error(f"❌ Failed to test reranker on {dataset_name}: {e}")
                continue

        return results

    def _validate_pipeline(self) -> Dict[str, Any]:
        """Validate end-to-end retrieval + reranking pipeline"""
        results = {
            'baseline_pipeline': {},
            'finetuned_pipeline': {},
            'comparisons': {},
            'datasets_tested': []
        }

        for dataset_path in self.config.test_datasets:
            if not os.path.exists(dataset_path):
                continue

            # Only test on datasets suitable for pipeline evaluation
            if 'embedding_data' not in dataset_path:
                continue

            dataset_name = Path(dataset_path).stem
            self.logger.info(f"Testing pipeline on dataset: {dataset_name}")

            try:
                # Test baseline pipeline
                baseline_metrics = self._test_pipeline_on_dataset(
                    retriever_path=self.config.retriever_baseline,
                    reranker_path=self.config.reranker_baseline,
                    dataset_path=dataset_path,
                    is_baseline=True
                )

                # Test fine-tuned pipeline
                finetuned_metrics = self._test_pipeline_on_dataset(
                    retriever_path=self.config.retriever_finetuned,
                    reranker_path=self.config.reranker_finetuned,
                    dataset_path=dataset_path,
                    is_baseline=False
                )

                # Compare results
                comparison = self._compare_metrics(baseline_metrics, finetuned_metrics, "pipeline")

                results['baseline_pipeline'][dataset_name] = baseline_metrics
                results['finetuned_pipeline'][dataset_name] = finetuned_metrics
                results['comparisons'][dataset_name] = comparison
                results['datasets_tested'].append(dataset_name)

                self.logger.info(f"✅ Pipeline validation complete for {dataset_name}")

            except Exception as e:
                self.logger.error(f"❌ Failed to test pipeline on {dataset_name}: {e}")
                continue

        return results

    def _test_retriever_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
        """Test a retriever model on a specific dataset"""
        try:
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGEM3Model.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            # Create dataset
            dataset = BGEM3Dataset(
                data_path=dataset_path,
                tokenizer=tokenizer,
                query_max_length=self.config.max_length,
                passage_max_length=self.config.max_length,
                is_train=False
            )

            dataloader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=False
            )

            # Create evaluator
            evaluator = RetrievalEvaluator(
                model=model,
                tokenizer=tokenizer,
                device=self.device,
                k_values=self.config.k_values
            )

            # Run evaluation
            metrics = evaluator.evaluate(
                dataloader=dataloader,
                desc=f"{'Baseline' if is_baseline else 'Fine-tuned'} Retriever"
            )

            return metrics

        except Exception as e:
            self.logger.error(f"Error testing retriever: {e}")
            return {}

    def _test_reranker_on_dataset(self, model_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
        """Test a reranker model on a specific dataset"""
        try:
            # Load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = BGERerankerModel.from_pretrained(model_path)
            model.to(self.device)
            model.eval()

            # Create dataset
            dataset = BGERerankerDataset(
                data_path=dataset_path,
                tokenizer=tokenizer,
                query_max_length=self.config.max_length,
                passage_max_length=self.config.max_length,
                is_train=False
            )

            dataloader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=False
            )

            # Collect predictions and labels
            all_scores = []
            all_labels = []

            with torch.no_grad():
                for batch in dataloader:
                    # Move to device
                    input_ids = batch['input_ids'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)
                    labels = batch['labels'].cpu().numpy()

                    # Get predictions
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    scores = outputs.logits.cpu().numpy()

                    all_scores.append(scores)
                    all_labels.append(labels)

            # Concatenate results
            all_scores = np.concatenate(all_scores, axis=0)
            all_labels = np.concatenate(all_labels, axis=0)

            # Compute metrics
            metrics = compute_reranker_metrics(
                scores=all_scores,
                labels=all_labels,
                k_values=self.config.k_values
            )

            return metrics

        except Exception as e:
            self.logger.error(f"Error testing reranker: {e}")
            return {}
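
    # The pipeline test below reads raw JSONL and keeps only records with 'query', 'pos'
    # and 'neg' keys, i.e. lines shaped like (illustrative example):
    #     {"query": "...", "pos": ["relevant passage", ...], "neg": ["irrelevant passage", ...]}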

    def _test_pipeline_on_dataset(self, retriever_path: str, reranker_path: str, dataset_path: str, is_baseline: bool) -> Dict[str, float]:
        """Test end-to-end pipeline on a dataset"""
        try:
            # Load models
            retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_path, trust_remote_code=True)
            retriever = BGEM3Model.from_pretrained(retriever_path)
            retriever.to(self.device)
            retriever.eval()

            reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_path, trust_remote_code=True)
            reranker = BGERerankerModel.from_pretrained(reranker_path)
            reranker.to(self.device)
            reranker.eval()

            # Load and preprocess data
            data = []
            with open(dataset_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        item = json.loads(line.strip())
                    except json.JSONDecodeError:
                        continue  # skip malformed lines
                    if isinstance(item, dict) and 'query' in item and 'pos' in item and 'neg' in item:
                        data.append(item)

            if not data:
                return {}

            # Test pipeline performance
            total_queries = 0
            pipeline_metrics = {f'recall@{k}': [] for k in self.config.k_values}
            pipeline_metrics.update({f'mrr@{k}': [] for k in self.config.k_values})

            for item in data[:100]:  # Limit to 100 queries for speed
                query = item['query']
                pos_docs = item['pos'][:2]  # Take first 2 positive docs
                neg_docs = item['neg'][:8]  # Take first 8 negative docs

                # Create candidate pool
                candidates = pos_docs + neg_docs
                labels = [1] * len(pos_docs) + [0] * len(neg_docs)

                if len(candidates) < 3:  # Need minimum candidates
                    continue

                try:
                    # Step 1: Retrieval (encode query and candidates)
                    query_inputs = retriever_tokenizer(
                        query,
                        return_tensors='pt',
                        padding=True,
                        truncation=True,
                        max_length=self.config.max_length
                    ).to(self.device)

                    with torch.no_grad():
                        query_outputs = retriever(**query_inputs, return_dense=True)
                        query_emb = query_outputs['dense']

                    # Encode candidates
                    candidate_embs = []
                    for candidate in candidates:
                        cand_inputs = retriever_tokenizer(
                            candidate,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=self.config.max_length
                        ).to(self.device)

                        with torch.no_grad():
                            cand_outputs = retriever(**cand_inputs, return_dense=True)
                            candidate_embs.append(cand_outputs['dense'])

                    candidate_embs = torch.cat(candidate_embs, dim=0)

                    # Compute similarities
                    similarities = torch.cosine_similarity(query_emb, candidate_embs, dim=1)

                    # Get top-k for retrieval
                    top_k_retrieval = min(self.config.retrieval_top_k, len(candidates))
                    retrieval_indices = torch.argsort(similarities, descending=True)[:top_k_retrieval].tolist()

                    # Step 2: Reranking
                    rerank_candidates = [candidates[i] for i in retrieval_indices]
                    rerank_labels = [labels[i] for i in retrieval_indices]

                    # Score pairs with the reranker
                    rerank_scores = []
                    for candidate in rerank_candidates:
                        # Encode as a (query, passage) pair so the tokenizer inserts its own
                        # separator tokens instead of a literal "[SEP]" string.
                        pair_inputs = reranker_tokenizer(
                            query,
                            candidate,
                            return_tensors='pt',
                            padding=True,
                            truncation=True,
                            max_length=self.config.max_length
                        ).to(self.device)

                        with torch.no_grad():
                            pair_outputs = reranker(**pair_inputs)
                            score = pair_outputs.logits.item()
                            rerank_scores.append(score)

                    # Final ranking
                    final_indices = np.argsort(rerank_scores)[::-1]
                    final_labels = [rerank_labels[i] for i in final_indices]

                    # Compute metrics
                    for k in self.config.k_values:
                        if k <= len(final_labels):
                            # Recall@k: relevant docs in top-k / relevant docs among reranked candidates
                            relevant_in_topk = sum(final_labels[:k])
                            total_relevant = sum(rerank_labels)
                            recall = relevant_in_topk / total_relevant if total_relevant > 0 else 0
                            pipeline_metrics[f'recall@{k}'].append(recall)

                            # MRR@k: 1 / rank of the first relevant doc in the top-k, else 0
                            for i, label in enumerate(final_labels[:k]):
                                if label == 1:
                                    pipeline_metrics[f'mrr@{k}'].append(1.0 / (i + 1))
                                    break
                            else:
                                pipeline_metrics[f'mrr@{k}'].append(0.0)

                    total_queries += 1

                except Exception as e:
                    self.logger.warning(f"Error processing query: {e}")
                    continue

            # Average metrics
            final_metrics = {}
            for metric, values in pipeline_metrics.items():
                if values:
                    final_metrics[metric] = np.mean(values)
                else:
                    final_metrics[metric] = 0.0

            final_metrics['total_queries'] = total_queries

            return final_metrics

        except Exception as e:
            self.logger.error(f"Error testing pipeline: {e}")
            return {}
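
    # Performance note: _test_pipeline_on_dataset encodes candidates one at a time, i.e. one
    # forward pass per candidate. A batched sketch, assuming the retriever accepts batched
    # inputs (as it does when fed through the DataLoader elsewhere in this script):
    #
    #     cand_inputs = retriever_tokenizer(
    #         candidates, return_tensors='pt', padding=True, truncation=True,
    #         max_length=self.config.max_length
    #     ).to(self.device)
    #     with torch.no_grad():
    #         candidate_embs = retriever(**cand_inputs, return_dense=True)['dense']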

    def _compare_metrics(self, baseline: Dict[str, float], finetuned: Dict[str, float], model_type: str) -> Dict[str, Any]:
        """Compare baseline vs fine-tuned metrics with statistical analysis"""
        comparison = {
            'improvements': {},
            'degradations': {},
            'statistical_significance': {},
            'summary': {}
        }

        # Compare common metrics
        common_metrics = set(baseline.keys()) & set(finetuned.keys())

        improvements = 0
        degradations = 0

        for metric in common_metrics:
            if isinstance(baseline[metric], (int, float)) and isinstance(finetuned[metric], (int, float)):
                baseline_val = baseline[metric]
                finetuned_val = finetuned[metric]

                # Calculate relative improvement in percent
                if baseline_val != 0:
                    improvement_pct = ((finetuned_val - baseline_val) / baseline_val) * 100
                else:
                    improvement_pct = float('inf') if finetuned_val > 0 else 0

                # Classify improvement/degradation
                if finetuned_val > baseline_val:
                    comparison['improvements'][metric] = {
                        'baseline': baseline_val,
                        'finetuned': finetuned_val,
                        'improvement_pct': improvement_pct,
                        'absolute_improvement': finetuned_val - baseline_val
                    }
                    improvements += 1
                elif finetuned_val < baseline_val:
                    comparison['degradations'][metric] = {
                        'baseline': baseline_val,
                        'finetuned': finetuned_val,
                        'degradation_pct': -improvement_pct,
                        'absolute_degradation': baseline_val - finetuned_val
                    }
                    degradations += 1

        # Summary statistics
        comparison['summary'] = {
            'total_metrics': len(common_metrics),
            'improvements': improvements,
            'degradations': degradations,
            'unchanged': len(common_metrics) - improvements - degradations,
            'improvement_ratio': improvements / len(common_metrics) if common_metrics else 0,
            'net_improvement': improvements - degradations
        }

        return comparison
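
    # Worked example for _compare_metrics (illustrative numbers): a metric moving from a
    # baseline of 0.72 to a fine-tuned value of 0.76 lands in 'improvements' with
    # improvement_pct = (0.76 - 0.72) / 0.72 * 100 ≈ 5.56 and absolute_improvement = 0.04.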

    def _perform_statistical_analysis(self):
        """Perform statistical analysis on validation results"""
        self.logger.info("📈 Performing Statistical Analysis...")

        analysis = {
            'overall_summary': {},
            'significance_tests': {},
            'confidence_intervals': {},
            'effect_sizes': {}
        }

        # Analyze each model type
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type not in self.results:
                continue

            model_results = self.results[model_type]

            # Collect all improvements across datasets
            all_improvements = []
            significant_improvements = 0

            for dataset, comparison in model_results.get('comparisons', {}).items():
                for metric, improvement_data in comparison.get('improvements', {}).items():
                    improvement_pct = improvement_data.get('improvement_pct', 0)
                    if improvement_pct > 0:
                        all_improvements.append(improvement_pct)

                        # Simple significance test (can be enhanced)
                        if improvement_pct > 5.0:  # 5% improvement threshold
                            significant_improvements += 1

            if all_improvements:
                analysis[model_type] = {
                    'mean_improvement_pct': np.mean(all_improvements),
                    'median_improvement_pct': np.median(all_improvements),
                    'std_improvement_pct': np.std(all_improvements),
                    'min_improvement_pct': np.min(all_improvements),
                    'max_improvement_pct': np.max(all_improvements),
                    'significant_improvements': significant_improvements,
                    'total_improvements': len(all_improvements)
                }

        self.results['statistical_analysis'] = analysis
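
    # The 'significance_tests' placeholder above could be backed by a paired test over
    # per-dataset or per-query scores. A sketch, assuming equal-length lists
    # baseline_scores / finetuned_scores and using scipy.stats (already imported as `stats`):
    #
    #     stat, p_value = stats.wilcoxon(finetuned_scores, baseline_scores)
    #     is_significant = p_value < self.config.significance_level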

    def _generate_report(self):
        """Generate comprehensive validation report"""
        self.logger.info("📝 Generating Comprehensive Report...")

        report_path = os.path.join(self.config.output_dir, "validation_report.html")

        html_content = self._create_html_report()

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        # Also save JSON results
        json_path = os.path.join(self.config.output_dir, "validation_results.json")
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, default=str)

        self.logger.info(f"📊 Report saved to: {report_path}")
        self.logger.info(f"📊 JSON results saved to: {json_path}")

    def _create_html_report(self) -> str:
        """Create HTML report with comprehensive validation results"""
        html = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>BGE Model Validation Report</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                .header {{ background: #f0f8ff; padding: 20px; border-radius: 10px; }}
                .section {{ margin: 20px 0; }}
                .metrics-table {{ border-collapse: collapse; width: 100%; }}
                .metrics-table th, .metrics-table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
                .metrics-table th {{ background-color: #f2f2f2; }}
                .improvement {{ color: green; font-weight: bold; }}
                .degradation {{ color: red; font-weight: bold; }}
                .summary-box {{ background: #f9f9f9; padding: 15px; border-left: 4px solid #4CAF50; margin: 10px 0; }}
            </style>
        </head>
        <body>
            <div class="header">
                <h1>🚀 BGE Model Comprehensive Validation Report</h1>
                <p><strong>Generated:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
                <p><strong>Configuration:</strong></p>
                <ul>
                    <li>Retriever Baseline: {self.config.retriever_baseline}</li>
                    <li>Retriever Fine-tuned: {self.config.retriever_finetuned or 'Not tested'}</li>
                    <li>Reranker Baseline: {self.config.reranker_baseline}</li>
                    <li>Reranker Fine-tuned: {self.config.reranker_finetuned or 'Not tested'}</li>
                </ul>
            </div>
        """

        # Add executive summary
        html += self._create_executive_summary()

        # Add detailed results for each model type
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type in self.results:
                html += self._create_model_section(model_type)

        html += """
        </body>
        </html>
        """
        return html

    def _create_executive_summary(self) -> str:
        """Create executive summary section"""
        html = """
        <div class="section">
            <h2>📊 Executive Summary</h2>
        """

        # Overall verdict
        total_improvements = 0
        total_degradations = 0

        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type in self.results and 'comparisons' in self.results[model_type]:
                for dataset, comparison in self.results[model_type]['comparisons'].items():
                    total_improvements += len(comparison.get('improvements', {}))
                    total_degradations += len(comparison.get('degradations', {}))

        if total_improvements > total_degradations:
            verdict = "✅ <strong>FINE-TUNING SUCCESSFUL</strong> - Models show overall improvement"
            verdict_color = "green"
        elif total_improvements == total_degradations:
            verdict = "⚠️ <strong>MIXED RESULTS</strong> - Equal improvements and degradations"
            verdict_color = "orange"
        else:
            verdict = "❌ <strong>FINE-TUNING ISSUES</strong> - More degradations than improvements"
            verdict_color = "red"

        html += f"""
            <div class="summary-box" style="border-left-color: {verdict_color};">
                <h3>{verdict}</h3>
                <ul>
                    <li>Total Metric Improvements: {total_improvements}</li>
                    <li>Total Metric Degradations: {total_degradations}</li>
                    <li>Net Improvement: {total_improvements - total_degradations}</li>
                </ul>
            </div>
        """

        html += """
        </div>
        """
        return html

    def _create_model_section(self, model_type: str) -> str:
        """Create detailed section for a model type"""
        html = f"""
        <div class="section">
            <h2>📈 {model_type.title()} Model Results</h2>
        """

        model_results = self.results[model_type]

        for dataset in model_results.get('datasets_tested', []):
            if dataset in model_results['comparisons']:
                comparison = model_results['comparisons'][dataset]

                html += f"""
                <h3>Dataset: {dataset}</h3>
                <table class="metrics-table">
                    <thead>
                        <tr>
                            <th>Metric</th>
                            <th>Baseline</th>
                            <th>Fine-tuned</th>
                            <th>Change</th>
                            <th>Status</th>
                        </tr>
                    </thead>
                    <tbody>
                """

                # Add improvements
                for metric, data in comparison.get('improvements', {}).items():
                    html += f"""
                        <tr>
                            <td>{metric}</td>
                            <td>{data['baseline']:.4f}</td>
                            <td>{data['finetuned']:.4f}</td>
                            <td class="improvement">+{data['improvement_pct']:.2f}%</td>
                            <td class="improvement">✅ Improved</td>
                        </tr>
                    """

                # Add degradations
                for metric, data in comparison.get('degradations', {}).items():
                    html += f"""
                        <tr>
                            <td>{metric}</td>
                            <td>{data['baseline']:.4f}</td>
                            <td>{data['finetuned']:.4f}</td>
                            <td class="degradation">-{data['degradation_pct']:.2f}%</td>
                            <td class="degradation">❌ Degraded</td>
                        </tr>
                    """

                html += """
                    </tbody>
                </table>
                """

        html += """
        </div>
        """
        return html


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Comprehensive validation of BGE fine-tuned models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Validate both retriever and reranker
  python scripts/comprehensive_validation.py \\
      --retriever_finetuned ./output/bge-m3-enhanced/final_model \\
      --reranker_finetuned ./output/bge-reranker/final_model

  # Validate only retriever
  python scripts/comprehensive_validation.py \\
      --retriever_finetuned ./output/bge-m3-enhanced/final_model

  # Custom datasets and output
  python scripts/comprehensive_validation.py \\
      --retriever_finetuned ./output/bge-m3-enhanced/final_model \\
      --test_datasets data/my_test1.jsonl data/my_test2.jsonl \\
      --output_dir ./my_validation_results
"""
    )

    # Model paths
    parser.add_argument("--retriever_baseline", type=str, default="BAAI/bge-m3",
                        help="Path to baseline retriever model")
    parser.add_argument("--reranker_baseline", type=str, default="BAAI/bge-reranker-base",
                        help="Path to baseline reranker model")
    parser.add_argument("--retriever_finetuned", type=str, default=None,
                        help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_finetuned", type=str, default=None,
                        help="Path to fine-tuned reranker model")

    # Test data
    parser.add_argument("--test_datasets", type=str, nargs="+", default=None,
                        help="Paths to test datasets (JSONL format)")

    # Evaluation settings
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for evaluation")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Maximum sequence length")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 3, 5, 10, 20],
                        help="K values for evaluation metrics")

    # Output settings
    parser.add_argument("--output_dir", type=str, default="./validation_results",
                        help="Directory to save validation results")
    parser.add_argument("--save_embeddings", action="store_true",
                        help="Save embeddings for further analysis")
    parser.add_argument("--no_report", action="store_true",
                        help="Skip HTML report generation")

    # Pipeline settings
    parser.add_argument("--retrieval_top_k", type=int, default=100,
                        help="Top-k for retrieval stage")
    parser.add_argument("--rerank_top_k", type=int, default=10,
                        help="Top-k for reranking stage")

    return parser.parse_args()


def main():
    """Main validation function"""
    args = parse_args()

    # Validate arguments
    if not args.retriever_finetuned and not args.reranker_finetuned:
        print("❌ Error: At least one fine-tuned model must be specified")
        print("   Use --retriever_finetuned or --reranker_finetuned")
        return 1

    # Create configuration
    config = ValidationConfig(
        retriever_baseline=args.retriever_baseline,
        reranker_baseline=args.reranker_baseline,
        retriever_finetuned=args.retriever_finetuned,
        reranker_finetuned=args.reranker_finetuned,
        test_datasets=args.test_datasets,
        batch_size=args.batch_size,
        max_length=args.max_length,
        k_values=args.k_values,
        output_dir=args.output_dir,
        save_embeddings=args.save_embeddings,
        create_report=not args.no_report,
        retrieval_top_k=args.retrieval_top_k,
        rerank_top_k=args.rerank_top_k
    )

    # Run validation
    try:
        validator = ComprehensiveValidator(config)
        results = validator.validate_all()

        print("\n" + "=" * 80)
        print("🎉 COMPREHENSIVE VALIDATION COMPLETED!")
        print("=" * 80)
        print(f"📊 Results saved to: {config.output_dir}")
        print("\n📈 Quick Summary:")

        # Print quick summary
        for model_type in ['retriever', 'reranker', 'pipeline']:
            if model_type in results:
                model_results = results[model_type]
                total_datasets = len(model_results.get('datasets_tested', []))
                total_improvements = sum(len(comp.get('improvements', {}))
                                         for comp in model_results.get('comparisons', {}).values())
                total_degradations = sum(len(comp.get('degradations', {}))
                                         for comp in model_results.get('comparisons', {}).values())

                status = "✅" if total_improvements > total_degradations else "⚠️" if total_improvements == total_degradations else "❌"
                print(f"  {status} {model_type.title()}: {total_improvements} improvements, "
                      f"{total_degradations} degradations across {total_datasets} datasets")

        return 0

    except KeyboardInterrupt:
        print("\n❌ Validation interrupted by user")
        return 1
    except Exception as e:
        print(f"\n❌ Validation failed: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())