# bge_finetune/scripts/evaluate.py

"""
Evaluation script for fine-tuned BGE models
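
Example (paths are illustrative):
    python scripts/evaluate.py \
        --retriever_model ./output/bge-m3-finetuned \
        --eval_data ./data/eval.jsonl \
        --eval_pipeline --compare_baseline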
"""
import os
import sys
import argparse
import logging
import json
import torch
from transformers import AutoTokenizer
import numpy as np
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from evaluation.evaluator import RetrievalEvaluator, InteractiveEvaluator
from data.preprocessing import DataPreprocessor
from utils.logging import setup_logging, get_logger
logger = logging.getLogger(__name__)
def parse_args():
"""Parse command line arguments for evaluation."""
parser = argparse.ArgumentParser(description="Evaluate fine-tuned BGE models")
# Model arguments
parser.add_argument("--retriever_model", type=str, required=True,
help="Path to fine-tuned retriever model")
parser.add_argument("--reranker_model", type=str, default=None,
help="Path to fine-tuned reranker model")
parser.add_argument("--base_retriever", type=str, default="BAAI/bge-m3",
help="Base retriever model for comparison")
parser.add_argument("--base_reranker", type=str, default="BAAI/bge-reranker-base",
help="Base reranker model for comparison")
# Data arguments
parser.add_argument("--eval_data", type=str, required=True,
help="Path to evaluation data file")
parser.add_argument("--corpus_data", type=str, default=None,
help="Path to corpus data for retrieval evaluation")
parser.add_argument("--data_format", type=str, default="auto",
choices=["auto", "jsonl", "json", "csv", "tsv", "dureader", "msmarco"],
help="Format of input data")
# Evaluation arguments
parser.add_argument("--output_dir", type=str, default="./evaluation_results",
help="Directory to save evaluation results")
parser.add_argument("--batch_size", type=int, default=32,
help="Batch size for evaluation")
parser.add_argument("--max_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--k_values", type=int, nargs="+", default=[1, 5, 10, 20, 100],
help="K values for evaluation metrics")
# Evaluation modes
parser.add_argument("--eval_retriever", action="store_true",
help="Evaluate retriever model")
parser.add_argument("--eval_reranker", action="store_true",
help="Evaluate reranker model")
parser.add_argument("--eval_pipeline", action="store_true",
help="Evaluate full retrieval pipeline")
parser.add_argument("--compare_baseline", action="store_true",
help="Compare with baseline models")
parser.add_argument("--interactive", action="store_true",
help="Run interactive evaluation")
# Pipeline settings
parser.add_argument("--retrieval_top_k", type=int, default=100,
help="Top-k for initial retrieval")
parser.add_argument("--rerank_top_k", type=int, default=10,
help="Top-k for reranking")
# Other
parser.add_argument("--device", type=str, default="auto",
help="Device to use (auto, cpu, cuda)")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
args = parser.parse_args()
# Validation
if not any([args.eval_retriever, args.eval_reranker, args.eval_pipeline, args.interactive]):
args.eval_pipeline = True # Default to pipeline evaluation
return args
def load_evaluation_data(data_path: str, data_format: str):
"""Load evaluation data from a file."""
preprocessor = DataPreprocessor()
if data_format == "auto":
data_format = preprocessor._detect_format(data_path)
if data_format == "jsonl":
data = preprocessor._load_jsonl(data_path)
elif data_format == "json":
data = preprocessor._load_json(data_path)
elif data_format in ["csv", "tsv"]:
data = preprocessor._load_csv(data_path, delimiter="," if data_format == "csv" else "\t")
elif data_format == "dureader":
data = preprocessor._load_dureader(data_path)
elif data_format == "msmarco":
data = preprocessor._load_msmarco(data_path)
else:
raise ValueError(f"Unsupported format: {data_format}")
return data
def evaluate_retriever(args, retriever_model, tokenizer, eval_data):
"""Evaluate the retriever model."""
logger.info("Evaluating retriever model...")
evaluator = RetrievalEvaluator(
model=retriever_model,
tokenizer=tokenizer,
device=args.device,
k_values=args.k_values
)
# Prepare data for evaluation
queries = [item['query'] for item in eval_data]
# If we have passage data, create passage corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
    # Create relevance mapping: query -> indices of its positive passages in the corpus
    passage_to_idx = {passage: idx for idx, passage in enumerate(passages)}
    relevant_docs = {}
    for item in eval_data:
        query = item['query']
        relevant_indices = [passage_to_idx[pos] for pos in item.get('pos', []) if pos in passage_to_idx]
        if relevant_indices:
            relevant_docs[query] = relevant_indices
# Evaluate
if passages and relevant_docs:
metrics = evaluator.evaluate_retrieval_pipeline(
queries=queries,
corpus=passages,
relevant_docs=relevant_docs,
batch_size=args.batch_size
)
else:
logger.warning("No valid corpus or relevance data found. Running basic evaluation.")
# Create dummy evaluation
metrics = {"note": "Limited evaluation due to data format"}
return metrics
def evaluate_reranker(args, reranker_model, tokenizer, eval_data):
"""Evaluate the reranker model."""
logger.info("Evaluating reranker model...")
evaluator = RetrievalEvaluator(
model=reranker_model,
tokenizer=tokenizer,
device=args.device,
k_values=args.k_values
)
# Prepare data for reranker evaluation
total_queries = 0
total_correct = 0
total_mrr = 0
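    # accuracy@1: fraction of queries whose positive candidate receives the highest score
    # mrr: mean reciprocal rank of the positive among (1 positive + negatives)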
for item in eval_data:
if 'pos' not in item or 'neg' not in item:
continue
query = item['query']
positives = item['pos']
negatives = item['neg']
if not positives or not negatives:
continue
# Create candidates (1 positive + negatives)
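        # The single positive is placed at index 0; the rank checks below rely on this ordering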
candidates = positives[:1] + negatives
# Score all candidates
pairs = [(query, candidate) for candidate in candidates]
scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
# Check if positive is ranked first
best_idx = np.argmax(scores)
if best_idx == 0: # First is positive
total_correct += 1
total_mrr += 1.0
else:
# Find rank of positive (it's always at index 0)
sorted_indices = np.argsort(scores)[::-1]
positive_rank = np.where(sorted_indices == 0)[0][0] + 1
total_mrr += 1.0 / positive_rank
total_queries += 1
if total_queries > 0:
metrics = {
'accuracy@1': total_correct / total_queries,
'mrr': total_mrr / total_queries,
'total_queries': total_queries
}
else:
metrics = {"note": "No valid query-passage pairs found for reranker evaluation"}
return metrics
def evaluate_pipeline(args, retriever_model, reranker_model, tokenizer, eval_data):
"""Evaluate the full retrieval + reranking pipeline."""
logger.info("Evaluating full pipeline...")
# Prepare corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
if not passages:
return {"note": "No passages found for pipeline evaluation"}
# Pre-encode corpus with retriever
logger.info("Encoding corpus...")
corpus_embeddings = retriever_model.encode(
passages,
batch_size=args.batch_size,
return_numpy=True,
device=args.device
)
total_queries = 0
pipeline_metrics = {
'retrieval_recall@10': [],
'rerank_accuracy@1': [],
'rerank_mrr': []
}
for item in eval_data:
if 'pos' not in item or not item['pos']:
continue
query = item['query']
relevant_passages = item['pos']
# Step 1: Retrieval
query_embedding = retriever_model.encode(
[query],
return_numpy=True,
device=args.device
)
# Compute similarities
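        # Dot product is used as the similarity score; this assumes encode() returns
        # L2-normalized embeddings, so it is equivalent to cosine similarity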
similarities = np.dot(query_embedding, corpus_embeddings.T).squeeze()
# Get top-k for retrieval
top_k_indices = np.argsort(similarities)[::-1][:args.retrieval_top_k]
retrieved_passages = [passages[i] for i in top_k_indices]
# Check retrieval recall
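        # Capped recall@10: the denominator is min(|relevant|, 10), so a perfect top-10
        # scores 1.0 even when a query has more than 10 relevant passages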
retrieved_relevant = sum(1 for p in retrieved_passages[:10] if p in relevant_passages)
retrieval_recall = retrieved_relevant / min(len(relevant_passages), 10)
pipeline_metrics['retrieval_recall@10'].append(retrieval_recall)
# Step 2: Reranking (if reranker available)
if reranker_model is not None:
# Rerank top candidates
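            # Only the top rerank_top_k retrieved candidates are re-scored; the rerank
            # metrics below are computed over this window only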
rerank_candidates = retrieved_passages[:args.rerank_top_k]
pairs = [(query, candidate) for candidate in rerank_candidates]
if pairs:
rerank_scores = reranker_model.compute_score(pairs, batch_size=args.batch_size)
reranked_indices = np.argsort(rerank_scores)[::-1]
reranked_passages = [rerank_candidates[i] for i in reranked_indices]
# Check reranking performance
if reranked_passages[0] in relevant_passages:
pipeline_metrics['rerank_accuracy@1'].append(1.0)
pipeline_metrics['rerank_mrr'].append(1.0)
else:
pipeline_metrics['rerank_accuracy@1'].append(0.0)
# Find MRR
for rank, passage in enumerate(reranked_passages, 1):
if passage in relevant_passages:
pipeline_metrics['rerank_mrr'].append(1.0 / rank)
break
else:
pipeline_metrics['rerank_mrr'].append(0.0)
total_queries += 1
# Average metrics
final_metrics = {}
for metric, values in pipeline_metrics.items():
if values:
final_metrics[metric] = np.mean(values)
final_metrics['total_queries'] = total_queries
return final_metrics
def run_interactive_evaluation(args, retriever_model, reranker_model, tokenizer, eval_data):
"""Run interactive evaluation."""
logger.info("Starting interactive evaluation...")
# Prepare corpus
all_passages = set()
for item in eval_data:
if 'pos' in item:
all_passages.update(item['pos'])
if 'neg' in item:
all_passages.update(item['neg'])
passages = list(all_passages)
if not passages:
print("No passages found for interactive evaluation")
return
# Create interactive evaluator
evaluator = InteractiveEvaluator(
retriever=retriever_model,
reranker=reranker_model,
tokenizer=tokenizer,
corpus=passages,
device=args.device
)
print("\n=== Interactive Evaluation ===")
print("Enter queries to test the retrieval system.")
print("Type 'quit' to exit, 'sample' to test with sample queries.")
# Load some sample queries
sample_queries = [item['query'] for item in eval_data[:5]]
while True:
try:
user_input = input("\nEnter query: ").strip()
if user_input.lower() in ['quit', 'exit', 'q']:
break
elif user_input.lower() == 'sample':
print("\nSample queries:")
for i, query in enumerate(sample_queries):
print(f"{i + 1}. {query}")
continue
elif user_input.isdigit() and 1 <= int(user_input) <= len(sample_queries):
query = sample_queries[int(user_input) - 1]
else:
query = user_input
if not query:
continue
# Search
results = evaluator.search(query, top_k=10, rerank=reranker_model is not None)
print(f"\nResults for: {query}")
print("-" * 50)
for i, (idx, score, text) in enumerate(results):
print(f"{i + 1}. (Score: {score:.3f}) {text[:100]}...")
except KeyboardInterrupt:
print("\nExiting interactive evaluation...")
break
except Exception as e:
print(f"Error: {e}")
def main():
"""Run evaluation with robust error handling and config-driven defaults."""
    # Reuse the module-level argument parser instead of redefining it here;
    # this also restores the default of running the pipeline evaluation when
    # no evaluation mode flag is given.
    args = parse_args()
# Load config and set defaults
from utils.config_loader import get_config
config = get_config()
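    # CLI --device overrides the configured hardware.device; the final fallback is 'cuda'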
device = args.device or config.get('hardware', {}).get('device', 'cuda') # type: ignore[attr-defined]
# Setup logging
setup_logging(
output_dir=args.output_dir,
log_level=logging.INFO,
log_to_console=True,
log_to_file=True
)
logger = get_logger(__name__)
logger.info("Starting model evaluation")
logger.info(f"Arguments: {args}")
logger.info(f"Using device: {device}")
# Set device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
elif device == "npu":
npu = getattr(torch, "npu", None)
if npu is not None and hasattr(npu, "is_available") and npu.is_available():
device = "npu"
else:
logger.warning("NPU (Ascend) is not available. Falling back to CPU.")
device = "cpu"
    # Propagate the resolved device so helpers that read args.device use the same value
    args.device = device
    logger.info(f"Using device: {device}")
# Load models
logger.info("Loading models...")
    model_source = config.get('model_paths.source', 'huggingface')
# Load retriever
try:
retriever_model = BGEM3Model.from_pretrained(args.retriever_model)
retriever_tokenizer = None # Already loaded inside BGEM3Model if needed
retriever_model.to(device)
retriever_model.eval()
logger.info(f"Loaded retriever from {args.retriever_model} (source: {model_source})")
except Exception as e:
logger.error(f"Failed to load retriever: {e}")
return
# Load reranker (optional)
reranker_model = None
reranker_tokenizer = None
if args.reranker_model:
try:
reranker_model = BGERerankerModel(args.reranker_model)
reranker_tokenizer = None # Already loaded inside BGERerankerModel if needed
reranker_model.to(device)
reranker_model.eval()
logger.info(f"Loaded reranker from {args.reranker_model} (source: {model_source})")
except Exception as e:
logger.warning(f"Failed to load reranker: {e}")
# Load evaluation data
logger.info("Loading evaluation data...")
eval_data = load_evaluation_data(args.eval_data, args.data_format)
logger.info(f"Loaded {len(eval_data)} evaluation examples")
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Run evaluations
all_results = {}
if args.eval_retriever:
retriever_metrics = evaluate_retriever(args, retriever_model, retriever_tokenizer, eval_data)
all_results['retriever'] = retriever_metrics
logger.info(f"Retriever metrics: {retriever_metrics}")
if args.eval_reranker and reranker_model:
reranker_metrics = evaluate_reranker(args, reranker_model, reranker_tokenizer, eval_data)
all_results['reranker'] = reranker_metrics
logger.info(f"Reranker metrics: {reranker_metrics}")
if args.eval_pipeline:
pipeline_metrics = evaluate_pipeline(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
all_results['pipeline'] = pipeline_metrics
logger.info(f"Pipeline metrics: {pipeline_metrics}")
# Compare with baseline if requested
if args.compare_baseline:
logger.info("Loading baseline models for comparison...")
try:
base_retriever = BGEM3Model(args.base_retriever)
base_retriever.to(device)
base_retriever.eval()
base_retriever_metrics = evaluate_retriever(args, base_retriever, retriever_tokenizer, eval_data)
all_results['baseline_retriever'] = base_retriever_metrics
logger.info(f"Baseline retriever metrics: {base_retriever_metrics}")
if reranker_model:
base_reranker = BGERerankerModel(args.base_reranker)
base_reranker.to(device)
base_reranker.eval()
base_reranker_metrics = evaluate_reranker(args, base_reranker, reranker_tokenizer, eval_data)
all_results['baseline_reranker'] = base_reranker_metrics
logger.info(f"Baseline reranker metrics: {base_reranker_metrics}")
except Exception as e:
logger.warning(f"Failed to load baseline models: {e}")
# Save results
results_file = os.path.join(args.output_dir, "evaluation_results.json")
with open(results_file, 'w') as f:
json.dump(all_results, f, indent=2)
logger.info(f"Evaluation results saved to {results_file}")
# Print summary
print("\n" + "=" * 50)
print("EVALUATION SUMMARY")
print("=" * 50)
for model_type, metrics in all_results.items():
print(f"\n{model_type.upper()}:")
for metric, value in metrics.items():
            if isinstance(value, float):
                print(f" {metric}: {value:.4f}")
            else:
                print(f" {metric}: {value}")
# Interactive evaluation
if args.interactive:
run_interactive_evaluation(args, retriever_model, reranker_model, retriever_tokenizer, eval_data)
if __name__ == "__main__":
main()