init commit

evaluation/evaluator.py (511 lines, new file)

@@ -0,0 +1,511 @@
import json
import logging
import os
from collections import defaultdict
from typing import List, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from .metrics import (
    compute_retrieval_metrics,
    compute_reranker_metrics,
    evaluate_cross_lingual_retrieval,
    MetricsTracker
)

logger = logging.getLogger(__name__)


class RetrievalEvaluator:
    """
    Comprehensive evaluator for retrieval models (bi-encoder BGE-M3 or
    cross-encoder BGE reranker).

    Notes:
        - All parameters (model, tokenizer, device, metrics, k_values, etc.)
          should be passed from config-driven scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
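
    Example (illustrative sketch; assumes `cfg` is a plain dict produced by
    your own config loading, and that `model` and `tokenizer` are already
    instantiated):

        evaluator = RetrievalEvaluator(
            model=model,
            tokenizer=tokenizer,
            device=cfg.get("device", "cuda"),
            metrics=cfg.get("metrics", ["recall", "mrr", "ndcg"]),
            k_values=cfg.get("k_values", [1, 5, 10]),
        )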
    """

    def __init__(
        self,
        model: Union[BGEM3Model, BGERerankerModel],
        tokenizer,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        metrics: List[str] = ['recall', 'precision', 'map', 'mrr', 'ndcg'],
        k_values: List[int] = [1, 5, 10, 20, 100]
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.metrics = metrics
        self.k_values = k_values

        # Move model to device
        self.model.to(device)
        self.model.eval()

        # Metrics tracker
        self.tracker = MetricsTracker()

    def evaluate(
        self,
        dataloader: DataLoader,
        desc: str = "Evaluating",
        return_embeddings: bool = False
    ) -> Dict[str, Union[float, list, np.ndarray]]:
        """
        Evaluate model on dataset

        Args:
            dataloader: DataLoader with evaluation data
            desc: Description for progress bar
            return_embeddings: Whether to return computed embeddings

        Returns:
            Dictionary of evaluation metrics
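
        Example (illustrative; `eval_loader` is assumed to be a DataLoader
        built from your evaluation dataset with the same collator used in
        training):

            metrics = evaluator.evaluate(eval_loader, desc="Dev eval")
            for name, value in metrics.items():
                if isinstance(value, float):
                    print(f"{name}: {value:.4f}")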
        """
        self.model.eval()

        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []
        all_queries = []
        all_passages = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=desc):
                # Move batch to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # Get embeddings based on model type
                if isinstance(self.model, BGEM3Model):
                    # Bi-encoder evaluation
                    query_embeddings, passage_embeddings = self._evaluate_biencoder_batch(batch)

                    all_query_embeddings.append(query_embeddings.cpu())
                    all_passage_embeddings.append(passage_embeddings.cpu())

                else:
                    # Cross-encoder evaluation
                    scores = self._evaluate_crossencoder_batch(batch)
                    # For cross-encoder, we store scores instead of embeddings
                    all_query_embeddings.append(scores.cpu())

                # Collect labels
                if 'labels' in batch:
                    all_labels.append(batch['labels'].cpu())

                # Collect text if available
                if 'query' in batch:
                    all_queries.extend(batch['query'])
                if 'passages' in batch:
                    all_passages.extend(batch['passages'])

        # Concatenate results
        if isinstance(self.model, BGEM3Model):
            all_query_embeddings = torch.cat(all_query_embeddings, dim=0).numpy()
            all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0).numpy()
            all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None

            # Compute retrieval metrics
            if all_labels is not None:
                metrics = compute_retrieval_metrics(
                    all_query_embeddings,
                    all_passage_embeddings,
                    all_labels,
                    k_values=self.k_values
                )
            else:
                metrics = {}

        else:
            # Cross-encoder metrics
            all_scores = torch.cat(all_query_embeddings, dim=0).numpy()
            all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None

            if all_labels is not None:
                # Determine if we have groups (listwise) or pairs
                if hasattr(dataloader.dataset, 'train_group_size'):
                    groups = [getattr(dataloader.dataset, 'train_group_size')] * len(dataloader.dataset)  # type: ignore[arg-type]
                else:
                    groups = None

                metrics = compute_reranker_metrics(
                    all_scores,
                    all_labels,
                    groups=groups,
                    k_values=self.k_values
                )
            else:
                metrics = {}

        # Update tracker
        self.tracker.update(metrics, step=0)

        # Return embeddings if requested
        if return_embeddings and isinstance(self.model, BGEM3Model):
            metrics['query_embeddings'] = all_query_embeddings  # type: ignore[assignment]
            metrics['passage_embeddings'] = all_passage_embeddings  # type: ignore[assignment]
            metrics['queries'] = all_queries  # type: ignore[assignment]
            metrics['passages'] = all_passages  # type: ignore[assignment]

        return metrics  # type: ignore[return-value]

    def _evaluate_biencoder_batch(self, batch: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate a batch for bi-encoder model"""
        # Get query embeddings
        query_outputs = self.model(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )
        query_embeddings = query_outputs['dense']

        # Get passage embeddings
        # Handle multiple passages per query
        if batch['passage_input_ids'].dim() == 3:
            # Flatten batch dimension
            batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
            passage_input_ids = batch['passage_input_ids'].view(-1, seq_len)
            passage_attention_mask = batch['passage_attention_mask'].view(-1, seq_len)
        else:
            passage_input_ids = batch['passage_input_ids']
            passage_attention_mask = batch['passage_attention_mask']

        passage_outputs = self.model(
            input_ids=passage_input_ids,
            attention_mask=passage_attention_mask,
            return_dense=True
        )
        passage_embeddings = passage_outputs['dense']

        return query_embeddings, passage_embeddings

    def _evaluate_crossencoder_batch(self, batch: Dict) -> torch.Tensor:
        """Evaluate a batch for cross-encoder model"""
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        return outputs['logits']

    def evaluate_retrieval_pipeline(
        self,
        queries: List[str],
        corpus: List[str],
        relevant_docs: Dict[str, List[int]],
        batch_size: int = 32
    ) -> Dict[str, float]:
        """
        Evaluate full retrieval pipeline

        Args:
            queries: List of queries
            corpus: List of all documents
            relevant_docs: Dict mapping query to relevant doc indices
            batch_size: Batch size for encoding

        Returns:
            Evaluation metrics
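
        Example (illustrative sketch with toy data; queries are used as keys
        into `relevant_docs`):

            queries = ["what does the reranker do?"]
            corpus = [
                "The reranker scores query-passage pairs.",
                "An unrelated passage about something else.",
            ]
            relevant_docs = {"what does the reranker do?": [0]}
            metrics = evaluator.evaluate_retrieval_pipeline(queries, corpus, relevant_docs)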
        """
        logger.info(f"Evaluating retrieval pipeline with {len(queries)} queries over {len(corpus)} documents")

        # Encode all documents
        logger.info("Encoding corpus...")
        corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")

        # Encode queries and evaluate
        all_metrics = defaultdict(list)

        for i in tqdm(range(0, len(queries), batch_size), desc="Evaluating queries"):
            batch_queries = queries[i:i + batch_size]
            batch_relevant = [relevant_docs.get(q, []) for q in batch_queries]

            # Encode queries
            query_embeddings = self._encode_texts(batch_queries, batch_size, show_progress=False)

            # Create labels matrix
            labels = np.zeros((len(batch_queries), len(corpus)))
            for j, relevant_indices in enumerate(batch_relevant):
                for idx in relevant_indices:
                    if idx < len(corpus):
                        labels[j, idx] = 1

            # Compute metrics
            batch_metrics = compute_retrieval_metrics(
                query_embeddings,
                corpus_embeddings,
                labels,
                k_values=self.k_values
            )

            # Accumulate metrics
            for metric, value in batch_metrics.items():
                if metric != 'scores':
                    all_metrics[metric].append(value)

        # Average metrics
        final_metrics = {
            metric: np.mean(values) for metric, values in all_metrics.items()
        }

        return final_metrics  # type: ignore[return-value]

    def _encode_texts(
        self,
        texts: List[str],
        batch_size: int,
        desc: str = "Encoding",
        show_progress: bool = True
    ) -> np.ndarray:
        """Encode texts to embeddings"""
        if isinstance(self.model, BGEM3Model):
            # Use model's encode method
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                device=self.device
            )
            if isinstance(embeddings, torch.Tensor):
                embeddings = embeddings.cpu().numpy()
        else:
            # Cross-encoders score (query, passage) pairs and cannot encode standalone texts
            raise ValueError("Cannot encode texts with cross-encoder model")

        return embeddings  # type: ignore[return-value]

    def evaluate_cross_lingual(
        self,
        queries: Dict[str, List[str]],
        corpus: List[str],
        relevant_docs: Dict[str, Dict[str, List[int]]],
        batch_size: int = 32
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate cross-lingual retrieval

        Args:
            queries: Dict mapping language to queries
            corpus: List of documents (assumed to be in source language)
            relevant_docs: Nested dict of language -> query -> relevant indices
            batch_size: Batch size

        Returns:
            Metrics per language
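
        Example (illustrative sketch with toy data; the language codes and the
        shared `corpus` list are assumptions of this example):

            queries = {"en": ["solar power"], "fr": ["énergie solaire"]}
            relevant = {
                "en": {"solar power": [0]},
                "fr": {"énergie solaire": [0]},
            }
            results = evaluator.evaluate_cross_lingual(queries, corpus, relevant)
            # results["average"] holds metrics averaged across languages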
        """
        # Encode corpus once
        corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")

        # Evaluate each language
        results = {}

        for lang, lang_queries in queries.items():
            logger.info(f"Evaluating {lang} queries...")

            # Encode queries
            query_embeddings = self._encode_texts(lang_queries, batch_size, desc=f"Encoding {lang} queries")

            # Create labels
            labels = np.zeros((len(lang_queries), len(corpus)))
            for i, query in enumerate(lang_queries):
                if query in relevant_docs.get(lang, {}):
                    for idx in relevant_docs[lang][query]:
                        if idx < len(corpus):
                            labels[i, idx] = 1

            # Compute metrics
            metrics = compute_retrieval_metrics(
                query_embeddings,
                corpus_embeddings,
                labels,
                k_values=self.k_values
            )

            results[lang] = metrics

        # Compute average across languages
        if len(results) > 1:
            avg_metrics = {}
            for metric in results[list(results.keys())[0]]:
                if metric != 'scores':
                    values = [results[lang][metric] for lang in results]
                    avg_metrics[metric] = np.mean(values)
            results['average'] = avg_metrics

        return results

    def save_evaluation_results(
        self,
        metrics: Dict[str, float],
        output_dir: str,
        prefix: str = "eval"
    ):
        """Save evaluation results"""
        os.makedirs(output_dir, exist_ok=True)

        # Save metrics
        metrics_file = os.path.join(output_dir, f"{prefix}_metrics.json")
        with open(metrics_file, 'w') as f:
            json.dump(metrics, f, indent=2)

        # Save summary
        summary_file = os.path.join(output_dir, f"{prefix}_summary.txt")
        with open(summary_file, 'w') as f:
            f.write(f"Evaluation Results - {prefix}\n")
            f.write("=" * 50 + "\n\n")

            for metric, value in sorted(metrics.items()):
                if isinstance(value, float):
                    f.write(f"{metric:.<30} {value:.4f}\n")

        logger.info(f"Saved evaluation results to {output_dir}")
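
    # Example (illustrative): after running `evaluate`, persist the metrics, e.g.
    #     evaluator.save_evaluation_results(metrics, output_dir="outputs/eval", prefix="dev")
    # which writes outputs/eval/dev_metrics.json and outputs/eval/dev_summary.txt.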


class InteractiveEvaluator:
    """
    Interactive evaluation for qualitative analysis
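
    Example (illustrative sketch; assumes `retriever`, `reranker`, and
    `tokenizer` are already loaded and `corpus` is a list of document strings):

        searcher = InteractiveEvaluator(retriever, reranker, tokenizer, corpus)
        hits = searcher.search("sample query", top_k=5)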
    """

    def __init__(
        self,
        retriever: BGEM3Model,
        reranker: Optional[BGERerankerModel],
        tokenizer,
        corpus: List[str],
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.tokenizer = tokenizer
        self.corpus = corpus
        self.device = device

        # Pre-encode corpus
        logger.info("Pre-encoding corpus...")
        self.corpus_embeddings = self._encode_corpus()
        logger.info(f"Encoded {len(self.corpus)} documents")

    def _encode_corpus(self) -> np.ndarray:
        """Pre-encode all corpus documents"""
        embeddings = self.retriever.encode(
            self.corpus,
            batch_size=32,
            device=self.device
        )
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()
        return embeddings  # type: ignore[return-value]

    def search(
        self,
        query: str,
        top_k: int = 10,
        rerank: bool = True
    ) -> List[Tuple[int, float, str]]:
        """
        Search for query in corpus

        Args:
            query: Search query
            top_k: Number of results to return
            rerank: Whether to use reranker

        Returns:
            List of (index, score, text) tuples
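
        Example (illustrative, continuing the class-level example where
        `searcher` is an InteractiveEvaluator):

            hits = searcher.search("renewable energy subsidies", top_k=5, rerank=False)
            for idx, score, text in hits:
                print(idx, round(score, 3), text[:80])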
        """
        # Encode query
        query_embedding = self.retriever.encode(
            [query],
            device=self.device
        )
        if isinstance(query_embedding, torch.Tensor):
            query_embedding = query_embedding.cpu().numpy()

        # Compute similarities
        similarities = np.dot(query_embedding, self.corpus_embeddings.T).squeeze()  # type: ignore[arg-type]

        # Get top-k
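        # When reranking, pull 2x top_k candidates so the reranker has a wider pool to reorder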
        top_indices = np.argsort(similarities)[::-1][:top_k * 2 if rerank else top_k]

        results = []
        for idx in top_indices:
            results.append((int(idx), float(similarities[idx]), self.corpus[idx]))

        # Rerank if requested
        if rerank and self.reranker:
            reranked_indices = self.reranker.rerank(
                query,
                [self.corpus[idx] for idx in top_indices],
                return_scores=False
            )

            # Reorder results
            reranked_results = []
            for i, rerank_idx in enumerate(reranked_indices[:top_k]):
                original_idx = top_indices[rerank_idx]  # type: ignore[index]
                reranked_results.append((
                    int(original_idx),
                    float(i + 1),  # Rank as score
                    self.corpus[original_idx]
                ))
            results = reranked_results

        return results[:top_k]

    def evaluate_query_interactively(
        self,
        query: str,
        relevant_indices: List[int],
        top_k: int = 10
    ) -> Dict[str, float]:
        """
        Evaluate a single query interactively

        Args:
            query: Query to evaluate
            relevant_indices: Ground truth relevant document indices
            top_k: Number of results to consider

        Returns:
            Metrics for this query
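
        Example (illustrative; the indices are toy values referring to
        positions in `self.corpus`):

            metrics = searcher.evaluate_query_interactively(
                "renewable energy subsidies",
                relevant_indices=[3, 17],
                top_k=10,
            )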
        """
        # Get search results
        results = self.search(query, top_k=top_k, rerank=True)
        retrieved_indices = [r[0] for r in results]

        # Calculate metrics
        metrics = {}

        # Precision@k
        relevant_retrieved = sum(1 for idx in retrieved_indices if idx in relevant_indices)
        metrics[f'precision@{top_k}'] = relevant_retrieved / top_k

        # Recall@k
        if relevant_indices:
            metrics[f'recall@{top_k}'] = relevant_retrieved / len(relevant_indices)
        else:
            metrics[f'recall@{top_k}'] = 0.0

        # MRR
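        # (for/else: the else branch runs only if no relevant doc appears in the results)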
        for i, idx in enumerate(retrieved_indices):
            if idx in relevant_indices:
                metrics['mrr'] = 1.0 / (i + 1)
                break
        else:
            metrics['mrr'] = 0.0

        # Print results for inspection
        print(f"\nQuery: {query}")
        print(f"Relevant docs: {relevant_indices}")
        print(f"\nTop {top_k} results:")
        for i, (idx, score, text) in enumerate(results):
            relevance = "✓" if idx in relevant_indices else "✗"
            print(f"{i + 1}. [{relevance}] (idx={idx}, score={score:.3f}) {text[:100]}...")

        print(f"\nMetrics: {metrics}")

        return metrics