"""
|
|
Evaluation script for fine-tuned BGE models
|
|
"""
import os
import sys
import argparse
import logging
import json

import torch
import numpy as np
from transformers import AutoTokenizer

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger

logger = logging.getLogger(__name__)


def parse_args():
    """Parse command line arguments for evaluation."""
    parser = argparse.ArgumentParser(description="Evaluate fine-tuned BGE models")

    # Model arguments
    parser.add_argument("--retriever_model", type=str, required=True,
                        help="Path to fine-tuned retriever model")
    parser.add_argument("--reranker_model", type=str, default=None,
                        help="Path to fine-tuned reranker model")
    parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
                        help="Base retriever model for comparison")
    parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
                        help="Base reranker model for comparison")

    # Data arguments
    parser.add_argument("--eval_data", type=str, required=True,
                        help="Path to evaluation data file")
    parser.add_argument("--corpus_data", type=str, default=None,
                        help="Path to corpus data for retrieval evaluation")
    parser.add_argument("--data_format", type=str, default="auto",
                        choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
                        help="Format of input data")

    # Evaluation arguments
    parser.add_argument("--output_dir", type=str, default="./evaluation_results",
                        help="Directory to save evaluation results")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for evaluation")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Maximum sequence length")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
                        help="K values for evaluation metrics")

    # Evaluation modes
    parser.add_argument("--eval_retriever", action="store_true",
                        help="Evaluate retriever model")
    parser.add_argument("--eval_reranker", action="store_true",
                        help="Evaluate reranker model")
    parser.add_argument("--eval_pipeline", action="store_true",
                        help="Evaluate full retrieval pipeline")
    parser.add_argument("--compare_baseline", action="store_true",
                        help="Compare with baseline models")
    parser.add_argument("--interactive", action="store_true",
                        help="Run interactive evaluation")

    # Pipeline settings
    parser.add_argument("--retrieval_top_k", type=int, default=100,
                        help="Top-k for initial retrieval")
    parser.add_argument("--rerank_top_k", type=int, default=10,
                        help="Top-k for reranking")

    # Other
    parser.add_argument("--device", type=str, default="auto",
                        help="Device to use (auto, cpu, cuda, npu)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")

    args = parser.parse_args()

    # Validation: default to full-pipeline evaluation if no mode was selected
    if not any([args.eval_retriever, args.eval_reranker, args.eval_pipeline, args.interactive]):
        args.eval_pipeline = True

    return args
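# Expected evaluation record shape (inferred from how the fields are consumed below;
# 'neg' is optional for retriever/pipeline evaluation but required for reranker evaluation):
#   {"query": "...", "pos": ["relevant passage", ...], "neg": ["hard negative passage", ...]}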
def load_evaluation_data(data_path: str, data_format: str):
    """Load evaluation data from a file."""
    preprocessor = DataPreprocessor()

    if data_format == "auto":
        data_format = preprocessor._detect_format(data_path)

    if data_format == "jsonl":
        data = preprocessor._load_jsonl(data_path)
    elif data_format == "json":
        data = preprocessor._load_json(data_path)
    elif data_format in ["csv", "tsv"]:
        data = preprocessor._load_csv(data_path, delimiter="," if data_format == "csv" else "\t")
    elif data_format == "dureader":
        data = preprocessor._load_dureader(data_path)
    elif data_format == "msmarco":
        data = preprocessor._load_msmarco(data_path)
    else:
        raise ValueError(f"Unsupported format: {data_format}")

    return data


def evaluate_retriever(args, retriever_model, tokenizer, eval_data):
    """Evaluate the retriever model."""
    logger.info("Evaluating retriever model...")

    evaluator = RetrievalEvaluator(
        model=retriever_model,
        tokenizer=tokenizer,
        device=args.device,
        k_values=args.k_values
    )

    # Prepare data for evaluation
    queries = [item['query'] for item in eval_data]
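    # The retrieval corpus used here is the union of every positive and negative
    # passage in the eval file itself, so retrieval quality is measured against this
    # local corpus rather than an external collection (--corpus_data is parsed but
    # not currently consumed here).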
    # If we have passage data, create passage corpus
    all_passages = set()
    for item in eval_data:
        if 'pos' in item:
            all_passages.update(item['pos'])
        if 'neg' in item:
            all_passages.update(item['neg'])

    passages = list(all_passages)
    passage_to_idx = {passage: idx for idx, passage in enumerate(passages)}

    # Create relevance mapping (query -> indices of its positive passages in the corpus)
    relevant_docs = {}
    for item in eval_data:
        query = item['query']
        if 'pos' in item:
            relevant_indices = [passage_to_idx[pos] for pos in item['pos'] if pos in passage_to_idx]
            if relevant_indices:
                relevant_docs[query] = relevant_indices

    # Evaluate
    if passages and relevant_docs:
        metrics = evaluator.evaluate_retrieval_pipeline(
            queries=queries,
            corpus=passages,
            relevant_docs=relevant_docs,
            batch_size=args.batch_size
        )
    else:
        logger.warning("No valid corpus or relevance data found. Running basic evaluation.")
        metrics = {"note": "Limited evaluation due to data format"}

    return metrics


def evaluate_reranker(args, reranker_model, tokenizer, eval_data):
    """Evaluate the reranker model."""
    logger.info("Evaluating reranker model...")

    evaluator = RetrievalEvaluator(
        model=reranker_model,
        tokenizer=tokenizer,
        device=args.device,
        k_values=args.k_values
    )

    # Prepare data for reranker evaluation
    total_queries = 0
    total_correct = 0
    total_mrr = 0
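    # For each query, one positive is scored against its hard negatives with the
    # reranker. The positive is always placed at index 0 of the candidate list, so
    # accuracy@1 checks whether argmax(scores) == 0, and MRR uses the rank of index 0
    # after sorting the scores in descending order.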
    for item in eval_data:
        if 'pos' not in item or 'neg' not in item:
            continue

        query = item['query']
        positives = item['pos']
        negatives = item['neg']

        if not positives or not negatives:
            continue

        # Create candidates (1 positive + negatives)
        candidates = positives[:1] + negatives

        # Score all candidates
        pairs = [(query, candidate) for candidate in candidates]
        scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)

        # Check if the positive is ranked first
        best_idx = np.argmax(scores)
        if best_idx == 0:  # First candidate is the positive
            total_correct += 1
            total_mrr += 1.0
        else:
            # Find rank of the positive (it's always at index 0)
            sorted_indices = np.argsort(scores)[::-1]
            positive_rank = np.where(sorted_indices == 0)[0][0] + 1
            total_mrr += 1.0 / positive_rank

        total_queries += 1

    if total_queries > 0:
        metrics = {
            'accuracy@1': total_correct / total_queries,
            'mrr': total_mrr / total_queries,
            'total_queries': total_queries
        }
    else:
        metrics = {"note": "No valid query-passage pairs found for reranker evaluation"}

    return metrics


def evaluate_pipeline(args, retriever_model, reranker_model, tokenizer, eval_data):
    """Evaluate the full retrieval + reranking pipeline."""
    logger.info("Evaluating full pipeline...")

    # Prepare corpus
    all_passages = set()
    for item in eval_data:
        if 'pos' in item:
            all_passages.update(item['pos'])
        if 'neg' in item:
            all_passages.update(item['neg'])

    passages = list(all_passages)

    if not passages:
        return {"note": "No passages found for pipeline evaluation"}

    # Pre-encode corpus with retriever
    logger.info("Encoding corpus...")
    corpus_embeddings = retriever_model.encode(
        passages,
        batch_size=args.batch_size,
        return_numpy=True,
        device=args.device
    )

    total_queries = 0
    pipeline_metrics = {
        'retrieval_recall@10': [],
        'rerank_accuracy@1': [],
        'rerank_mrr': []
    }
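    # Per query, the pipeline runs two stages: (1) dense retrieval by dot-product
    # similarity between the query embedding and the pre-encoded corpus embeddings,
    # keeping the top --retrieval_top_k passages; (2) optional reranking of the top
    # --rerank_top_k candidates with the cross-encoder, from which accuracy@1 and MRR
    # of the reranked list are recorded. Dot-product ranking assumes the encoder
    # returns embeddings on a comparable scale (e.g. L2-normalized vectors).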
    for item in eval_data:
        if 'pos' not in item or not item['pos']:
            continue

        query = item['query']
        relevant_passages = item['pos']

        # Step 1: Retrieval
        query_embedding = retriever_model.encode(
            [query],
            return_numpy=True,
            device=args.device
        )

        # Compute similarities against the pre-encoded corpus
        similarities = np.dot(query_embedding, corpus_embeddings.T).squeeze()

        # Get top-k for retrieval
        top_k_indices = np.argsort(similarities)[::-1][:args.retrieval_top_k]
        retrieved_passages = [passages[i] for i in top_k_indices]

        # Check retrieval recall@10
        retrieved_relevant = sum(1 for p in retrieved_passages[:10] if p in relevant_passages)
        retrieval_recall = retrieved_relevant / min(len(relevant_passages), 10)
        pipeline_metrics['retrieval_recall@10'].append(retrieval_recall)

        # Step 2: Reranking (if a reranker is available)
        if reranker_model is not None:
            # Rerank top candidates
            rerank_candidates = retrieved_passages[:args.rerank_top_k]
            pairs = [(query, candidate) for candidate in rerank_candidates]

            if pairs:
                rerank_scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
                reranked_indices = np.argsort(rerank_scores)[::-1]
                reranked_passages = [rerank_candidates[i] for i in reranked_indices]

                # Check reranking performance
                if reranked_passages[0] in relevant_passages:
                    pipeline_metrics['rerank_accuracy@1'].append(1.0)
                    pipeline_metrics['rerank_mrr'].append(1.0)
                else:
                    pipeline_metrics['rerank_accuracy@1'].append(0.0)
                    # Reciprocal rank of the first relevant passage (0 if none was retrieved)
                    for rank, passage in enumerate(reranked_passages, 1):
                        if passage in relevant_passages:
                            pipeline_metrics['rerank_mrr'].append(1.0 / rank)
                            break
                    else:
                        pipeline_metrics['rerank_mrr'].append(0.0)

        total_queries += 1

    # Average metrics (cast to float so the results stay JSON-serializable)
    final_metrics = {}
    for metric, values in pipeline_metrics.items():
        if values:
            final_metrics[metric] = float(np.mean(values))

    final_metrics['total_queries'] = total_queries

    return final_metrics


def run_interactive_evaluation(args, retriever_model, reranker_model, tokenizer, eval_data):
    """Run interactive evaluation from the command line."""
    logger.info("Starting interactive evaluation...")

    # Prepare corpus
    all_passages = set()
    for item in eval_data:
        if 'pos' in item:
            all_passages.update(item['pos'])
        if 'neg' in item:
            all_passages.update(item['neg'])

    passages = list(all_passages)

    if not passages:
        print("No passages found for interactive evaluation")
        return

    # Create interactive evaluator
    evaluator = InteractiveEvaluator(
        retriever=retriever_model,
        reranker=reranker_model,
        tokenizer=tokenizer,
        corpus=passages,
        device=args.device
    )

    print("\n=== Interactive Evaluation ===")
    print("Enter queries to test the retrieval system.")
    print("Type 'quit' to exit, 'sample' to show sample queries.")

    # Load some sample queries from the evaluation data
    sample_queries = [item['query'] for item in eval_data[:5]]
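    # Input handling: 'quit'/'exit'/'q' ends the session, 'sample' lists the sample
    # queries above, entering a number between 1 and len(sample_queries) runs that
    # sample query, and any other input is treated as a free-form query.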
    while True:
        try:
            user_input = input("\nEnter query: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                break
            elif user_input.lower() == 'sample':
                print("\nSample queries:")
                for i, query in enumerate(sample_queries):
                    print(f"{i + 1}. {query}")
                continue
            elif user_input.isdigit() and 1 <= int(user_input) <= len(sample_queries):
                query = sample_queries[int(user_input) - 1]
            else:
                query = user_input

            if not query:
                continue

            # Search
            results = evaluator.search(query, top_k=10, rerank=reranker_model is not None)

            print(f"\nResults for: {query}")
            print("-" * 50)
            for i, (idx, score, text) in enumerate(results):
                print(f"{i + 1}. (Score: {score:.3f}) {text[:100]}...")

        except KeyboardInterrupt:
            print("\nExiting interactive evaluation...")
            break
        except Exception as e:
            print(f"Error: {e}")


def main():
    """Run evaluation with config-driven defaults and robust error handling."""
    args = parse_args()

    # Load config to fill in defaults that were not given on the command line
    from utils.config_loader import get_config
    config = get_config()

    # Setup logging
    setup_logging(
        output_dir=args.output_dir,
        log_level=logging.INFO,
        log_to_console=True,
        log_to_file=True
    )

    logger = get_logger(__name__)
    logger.info("Starting model evaluation")
    logger.info(f"Arguments: {args}")

    # Resolve device: explicit CLI value > config default > automatic detection
    device = args.device
    if device in (None, "auto"):
        device = config.get('hardware', {}).get('device', 'auto')
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    elif device == "npu":
        npu = getattr(torch, "npu", None)
        if npu is not None and hasattr(npu, "is_available") and npu.is_available():
            device = "npu"
        else:
            logger.warning("NPU (Ascend) is not available. Falling back to CPU.")
            device = "cpu"
    args.device = device  # propagate the resolved device to the evaluation helpers

    logger.info(f"Using device: {device}")

    # Load models
    logger.info("Loading models...")
    model_source = config.get('model_paths.source', 'huggingface')

    # Load retriever
    try:
        retriever_model = BGEM3Model.from_pretrained(args.retriever_model)
        retriever_tokenizer = None  # Tokenizer is loaded inside BGEM3Model if needed
        retriever_model.to(device)
        retriever_model.eval()
        logger.info(f"Loaded retriever from {args.retriever_model} (source: {model_source})")
    except Exception as e:
        logger.error(f"Failed to load retriever: {e}")
        return

    # Load reranker (optional)
    reranker_model = None
    reranker_tokenizer = None
    if args.reranker_model:
        try:
            reranker_model = BGERerankerModel(args.reranker_model)
            reranker_tokenizer = None  # Tokenizer is loaded inside BGERerankerModel if needed
            reranker_model.to(device)
            reranker_model.eval()
            logger.info(f"Loaded reranker from {args.reranker_model} (source: {model_source})")
        except Exception as e:
            logger.warning(f"Failed to load reranker: {e}")

    # Load evaluation data
    logger.info("Loading evaluation data...")
    eval_data = load_evaluation_data(args.eval_data, args.data_format)
    logger.info(f"Loaded {len(eval_data)} evaluation examples")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Run evaluations
    all_results = {}

    if args.eval_retriever:
        retriever_metrics = evaluate_retriever(args, retriever_model, retriever_tokenizer, eval_data)
        all_results['retriever'] = retriever_metrics
        logger.info(f"Retriever metrics: {retriever_metrics}")

    if args.eval_reranker and reranker_model:
        reranker_metrics = evaluate_reranker(args, reranker_model, reranker_tokenizer, eval_data)
        all_results['reranker'] = reranker_metrics
        logger.info(f"Reranker metrics: {reranker_metrics}")

    if args.eval_pipeline:
        pipeline_metrics = evaluate_pipeline(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
        all_results['pipeline'] = pipeline_metrics
        logger.info(f"Pipeline metrics: {pipeline_metrics}")

    # Compare with baseline models if requested
    if args.compare_baseline:
        logger.info("Loading baseline models for comparison...")
        try:
            base_retriever = BGEM3Model(args.base_retriever)
            base_retriever.to(device)
            base_retriever.eval()

            base_retriever_metrics = evaluate_retriever(args, base_retriever, retriever_tokenizer, eval_data)
            all_results['baseline_retriever'] = base_retriever_metrics
            logger.info(f"Baseline retriever metrics: {base_retriever_metrics}")

            if reranker_model:
                base_reranker = BGERerankerModel(args.base_reranker)
                base_reranker.to(device)
                base_reranker.eval()

                base_reranker_metrics = evaluate_reranker(args, base_reranker, reranker_tokenizer, eval_data)
                all_results['baseline_reranker'] = base_reranker_metrics
                logger.info(f"Baseline reranker metrics: {base_reranker_metrics}")
        except Exception as e:
            logger.warning(f"Failed to load baseline models: {e}")
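    # Persist all metrics as a single JSON file; top-level keys are the evaluated
    # components ('retriever', 'reranker', 'pipeline', 'baseline_retriever',
    # 'baseline_reranker'), each mapping to its metric dictionary.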
    # Save results
    results_file = os.path.join(args.output_dir, "evaluation_results.json")
    with open(results_file, 'w') as f:
        json.dump(all_results, f, indent=2)

    logger.info(f"Evaluation results saved to {results_file}")

    # Print summary
    print("\n" + "=" * 50)
    print("EVALUATION SUMMARY")
    print("=" * 50)
    for model_type, metrics in all_results.items():
        print(f"\n{model_type.upper()}:")
        for metric, value in metrics.items():
            if isinstance(value, (int, float)):
                print(f" {metric}: {value:.4f}")
            else:
                print(f" {metric}: {value}")

    # Interactive evaluation
    if args.interactive:
        run_interactive_evaluation(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)


if __name__ == "__main__":
    main()