import torch
import numpy as np
from typing import List, Dict, Optional, Tuple, Union
from tqdm import tqdm
import logging
from torch.utils.data import DataLoader
import json
import os
from collections import defaultdict

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from .metrics import (
    compute_retrieval_metrics,
    compute_reranker_metrics,
    evaluate_cross_lingual_retrieval,
    MetricsTracker
)

logger = logging.getLogger(__name__)


class RetrievalEvaluator:
    """
    Comprehensive evaluator for retrieval models (bi-encoder BGEM3Model or
    cross-encoder BGERerankerModel).

    Notes:
        - All parameters (model, tokenizer, device, metrics, k_values, etc.) should be
          passed in from config-driven scripts.
        - This class does not load config directly; pass config values from your main
          script for consistency.
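
    Example (illustrative; `model`, `tokenizer`, and `eval_loader` are assumed to be
    built by your config-driven script):
        evaluator = RetrievalEvaluator(
            model=model,            # BGEM3Model or BGERerankerModel
            tokenizer=tokenizer,
            k_values=[1, 5, 10],
        )
        metrics = evaluator.evaluate(eval_loader)
        evaluator.save_evaluation_results(metrics, output_dir="outputs/eval")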
    """

    def __init__(
        self,
        model: Union[BGEM3Model, BGERerankerModel],
        tokenizer,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        metrics: List[str] = ['recall', 'precision', 'map', 'mrr', 'ndcg'],
        k_values: List[int] = [1, 5, 10, 20, 100]
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.metrics = metrics
        self.k_values = k_values

        # Move model to device
        self.model.to(device)
        self.model.eval()

        # Metrics tracker
        self.tracker = MetricsTracker()

    def evaluate(
        self,
        dataloader: DataLoader,
        desc: str = "Evaluating",
        return_embeddings: bool = False
    ) -> Dict[str, Union[float, list, np.ndarray]]:
        """
        Evaluate model on dataset

        Args:
            dataloader: DataLoader with evaluation data
            desc: Description for progress bar
            return_embeddings: Whether to return computed embeddings

        Returns:
            Dictionary of evaluation metrics
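            Illustrative shape (the exact keys come from compute_retrieval_metrics /
            compute_reranker_metrics and are not guaranteed by this method):
                {"recall@10": 0.82, "mrr": 0.64, ...}
            When return_embeddings=True and the model is a BGEM3Model, the dict also
            contains 'query_embeddings', 'passage_embeddings', 'queries', and 'passages'.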
        """
        self.model.eval()

        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []
        all_queries = []
        all_passages = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=desc):
                # Move batch to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # Get embeddings based on model type
                if isinstance(self.model, BGEM3Model):
                    # Bi-encoder evaluation
                    query_embeddings, passage_embeddings = self._evaluate_biencoder_batch(batch)

                    all_query_embeddings.append(query_embeddings.cpu())
                    all_passage_embeddings.append(passage_embeddings.cpu())

                else:
                    # Cross-encoder evaluation
                    scores = self._evaluate_crossencoder_batch(batch)
                    # For the cross-encoder we store scores instead of embeddings
                    all_query_embeddings.append(scores.cpu())

                # Collect labels
                if 'labels' in batch:
                    all_labels.append(batch['labels'].cpu())

                # Collect text if available
                if 'query' in batch:
                    all_queries.extend(batch['query'])
                if 'passages' in batch:
                    all_passages.extend(batch['passages'])

        # Concatenate results
        if isinstance(self.model, BGEM3Model):
            all_query_embeddings = torch.cat(all_query_embeddings, dim=0).numpy()
            all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0).numpy()
            all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None

            # Compute retrieval metrics
            if all_labels is not None:
                metrics = compute_retrieval_metrics(
                    all_query_embeddings,
                    all_passage_embeddings,
                    all_labels,
                    k_values=self.k_values
                )
            else:
                metrics = {}

        else:
            # Cross-encoder metrics
            all_scores = torch.cat(all_query_embeddings, dim=0).numpy()
            all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None

            if all_labels is not None:
                # Determine if we have groups (listwise) or pairs
                if hasattr(dataloader.dataset, 'train_group_size'):
                    groups = [getattr(dataloader.dataset, 'train_group_size')] * len(dataloader.dataset)  # type: ignore[arg-type]
                else:
                    groups = None

                metrics = compute_reranker_metrics(
                    all_scores,
                    all_labels,
                    groups=groups,
                    k_values=self.k_values
                )
            else:
                metrics = {}

        # Update tracker
        self.tracker.update(metrics, step=0)

        # Return embeddings if requested
        if return_embeddings and isinstance(self.model, BGEM3Model):
            metrics['query_embeddings'] = all_query_embeddings  # type: ignore[assignment]
            metrics['passage_embeddings'] = all_passage_embeddings  # type: ignore[assignment]
            metrics['queries'] = all_queries  # type: ignore[assignment]
            metrics['passages'] = all_passages  # type: ignore[assignment]

        return metrics  # type: ignore[return-value]

    def _evaluate_biencoder_batch(self, batch: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate a batch for the bi-encoder model"""
        # Get query embeddings
        query_outputs = self.model(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )
        query_embeddings = query_outputs['dense']

        # Get passage embeddings
        # Handle multiple passages per query
        if batch['passage_input_ids'].dim() == 3:
            # Flatten the (batch, num_passages) dimensions into one
            batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
            passage_input_ids = batch['passage_input_ids'].view(-1, seq_len)
            passage_attention_mask = batch['passage_attention_mask'].view(-1, seq_len)
        else:
            passage_input_ids = batch['passage_input_ids']
            passage_attention_mask = batch['passage_attention_mask']

        passage_outputs = self.model(
            input_ids=passage_input_ids,
            attention_mask=passage_attention_mask,
            return_dense=True
        )
        passage_embeddings = passage_outputs['dense']

        return query_embeddings, passage_embeddings

    def _evaluate_crossencoder_batch(self, batch: Dict) -> torch.Tensor:
        """Evaluate a batch for the cross-encoder model"""
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        return outputs['logits']

    def evaluate_retrieval_pipeline(
        self,
        queries: List[str],
        corpus: List[str],
        relevant_docs: Dict[str, List[int]],
        batch_size: int = 32
    ) -> Dict[str, float]:
        """
        Evaluate the full retrieval pipeline

        Args:
            queries: List of queries
            corpus: List of all documents
            relevant_docs: Dict mapping each query string to the indices of its
                relevant documents in `corpus`
            batch_size: Batch size for encoding

        Returns:
            Evaluation metrics averaged over query batches
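
        Example (illustrative):
            queries = ["what is dense retrieval?"]
            corpus = ["Dense retrieval encodes queries and documents into vectors.",
                      "An unrelated passage."]
            relevant_docs = {"what is dense retrieval?": [0]}
            metrics = evaluator.evaluate_retrieval_pipeline(queries, corpus, relevant_docs)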
        """
        logger.info(f"Evaluating retrieval pipeline with {len(queries)} queries over {len(corpus)} documents")

        # Encode all documents
        logger.info("Encoding corpus...")
        corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")

        # Encode queries and evaluate
        all_metrics = defaultdict(list)

        for i in tqdm(range(0, len(queries), batch_size), desc="Evaluating queries"):
            batch_queries = queries[i:i + batch_size]
            batch_relevant = [relevant_docs.get(q, []) for q in batch_queries]

            # Encode queries
            query_embeddings = self._encode_texts(batch_queries, batch_size, show_progress=False)

            # Create labels matrix
            labels = np.zeros((len(batch_queries), len(corpus)))
            for j, relevant_indices in enumerate(batch_relevant):
                for idx in relevant_indices:
                    if idx < len(corpus):
                        labels[j, idx] = 1

            # Compute metrics
            batch_metrics = compute_retrieval_metrics(
                query_embeddings,
                corpus_embeddings,
                labels,
                k_values=self.k_values
            )

            # Accumulate metrics
            for metric, value in batch_metrics.items():
                if metric != 'scores':
                    all_metrics[metric].append(value)

        # Average metrics over batches
        # Note: this is a mean of per-batch values, so a smaller final batch is
        # weighted the same as a full batch.
        final_metrics = {
            metric: np.mean(values) for metric, values in all_metrics.items()
        }

        return final_metrics  # type: ignore[return-value]

    def _encode_texts(
        self,
        texts: List[str],
        batch_size: int,
        desc: str = "Encoding",
        show_progress: bool = True
    ) -> np.ndarray:
        """Encode texts to embeddings.

        Note: `desc` and `show_progress` are accepted for call-site symmetry but are
        not currently forwarded to the underlying encode call.
        """
        if isinstance(self.model, BGEM3Model):
            # Use the model's encode method
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                device=self.device
            )
            if isinstance(embeddings, torch.Tensor):
                embeddings = embeddings.cpu().numpy()
        else:
            # Encoding standalone texts requires a bi-encoder, not a cross-encoder
            raise ValueError("Cannot encode texts with cross-encoder model")

        return embeddings  # type: ignore[return-value]

    def evaluate_cross_lingual(
        self,
        queries: Dict[str, List[str]],
        corpus: List[str],
        relevant_docs: Dict[str, Dict[str, List[int]]],
        batch_size: int = 32
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate cross-lingual retrieval

        Args:
            queries: Dict mapping language code to a list of queries
            corpus: List of documents (assumed to be in the source language)
            relevant_docs: Nested dict of language -> query -> relevant indices
            batch_size: Batch size for encoding

        Returns:
            Metrics per language (plus an 'average' entry when more than one
            language is evaluated)
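
        Example (illustrative):
            queries = {"en": ["solar power"], "fr": ["énergie solaire"]}
            relevant_docs = {"en": {"solar power": [3]}, "fr": {"énergie solaire": [3]}}
            results = evaluator.evaluate_cross_lingual(queries, corpus, relevant_docs)
            # results["en"], results["fr"], and results["average"]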
        """
        # Encode corpus once
        corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")

        # Evaluate each language
        results = {}

        for lang, lang_queries in queries.items():
            logger.info(f"Evaluating {lang} queries...")

            # Encode queries
            query_embeddings = self._encode_texts(lang_queries, batch_size, desc=f"Encoding {lang} queries")

            # Create labels
            labels = np.zeros((len(lang_queries), len(corpus)))
            for i, query in enumerate(lang_queries):
                if query in relevant_docs.get(lang, {}):
                    for idx in relevant_docs[lang][query]:
                        if idx < len(corpus):
                            labels[i, idx] = 1

            # Compute metrics
            metrics = compute_retrieval_metrics(
                query_embeddings,
                corpus_embeddings,
                labels,
                k_values=self.k_values
            )

            results[lang] = metrics

        # Compute average across languages
        if len(results) > 1:
            avg_metrics = {}
            for metric in results[list(results.keys())[0]]:
                if metric != 'scores':
                    values = [results[lang][metric] for lang in results]
                    avg_metrics[metric] = np.mean(values)
            results['average'] = avg_metrics

        return results

    def save_evaluation_results(
        self,
        metrics: Dict[str, float],
        output_dir: str,
        prefix: str = "eval"
    ):
        """Save evaluation metrics as JSON plus a human-readable summary"""
        os.makedirs(output_dir, exist_ok=True)

        # Save metrics
        metrics_file = os.path.join(output_dir, f"{prefix}_metrics.json")
        with open(metrics_file, 'w') as f:
            # default=float handles numpy scalar values that json cannot serialize
            json.dump(metrics, f, indent=2, default=float)

        # Save summary
        summary_file = os.path.join(output_dir, f"{prefix}_summary.txt")
        with open(summary_file, 'w') as f:
            f.write(f"Evaluation Results - {prefix}\n")
            f.write("=" * 50 + "\n\n")

            for metric, value in sorted(metrics.items()):
                if isinstance(value, float):
                    f.write(f"{metric:.<30} {value:.4f}\n")

        logger.info(f"Saved evaluation results to {output_dir}")


class InteractiveEvaluator:
    """
    Interactive retrieval (and optional reranking) over a fixed corpus for
    qualitative analysis.
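
    Example (illustrative; `retriever`, `reranker`, and `tokenizer` are assumed to be
    loaded elsewhere):
        evaluator = InteractiveEvaluator(
            retriever=retriever,   # BGEM3Model
            reranker=reranker,     # BGERerankerModel or None
            tokenizer=tokenizer,
            corpus=["first document ...", "second document ..."],
        )
        hits = evaluator.search("my query", top_k=5)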
    """

    def __init__(
        self,
        retriever: BGEM3Model,
        reranker: Optional[BGERerankerModel],
        tokenizer,
        corpus: List[str],
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.tokenizer = tokenizer
        self.corpus = corpus
        self.device = device

        # Pre-encode corpus
        logger.info("Pre-encoding corpus...")
        self.corpus_embeddings = self._encode_corpus()
        logger.info(f"Encoded {len(self.corpus)} documents")

    def _encode_corpus(self) -> np.ndarray:
        """Pre-encode all corpus documents"""
        embeddings = self.retriever.encode(
            self.corpus,
            batch_size=32,
            device=self.device
        )
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()
        return embeddings  # type: ignore[return-value]

    def search(
        self,
        query: str,
        top_k: int = 10,
        rerank: bool = True
    ) -> List[Tuple[int, float, str]]:
        """
        Search for a query in the corpus

        Args:
            query: Search query
            top_k: Number of results to return
            rerank: Whether to use the reranker (if one was provided)

        Returns:
            List of (index, score, text) tuples. Without reranking the score is the
            dense similarity; with reranking it is the 1-based rerank position.
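
            Example (illustrative):
                [(12, 0.83, "First matched document ..."),
                 (47, 0.79, "Second matched document ...")]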
        """
        # Encode query
        query_embedding = self.retriever.encode(
            [query],
            device=self.device
        )
        if isinstance(query_embedding, torch.Tensor):
            query_embedding = query_embedding.cpu().numpy()

        # Compute similarities
        similarities = np.dot(query_embedding, self.corpus_embeddings.T).squeeze()  # type: ignore[arg-type]

        # Get top-k candidates (retrieve 2x top_k when reranking so the reranker
        # has a larger pool to reorder)
        top_indices = np.argsort(similarities)[::-1][:top_k * 2 if rerank else top_k]

        results = []
        for idx in top_indices:
            results.append((int(idx), float(similarities[idx]), self.corpus[idx]))

        # Rerank if requested
        if rerank and self.reranker:
            reranked_indices = self.reranker.rerank(
                query,
                [self.corpus[idx] for idx in top_indices],
                return_scores=False
            )

            # Reorder results according to the reranker
            reranked_results = []
            for i, rerank_idx in enumerate(reranked_indices[:top_k]):
                original_idx = top_indices[rerank_idx]  # type: ignore[index]
                reranked_results.append((
                    int(original_idx),
                    float(i + 1),  # 1-based rerank position used as the score
                    self.corpus[original_idx]
                ))
            results = reranked_results

        return results[:top_k]

    def evaluate_query_interactively(
        self,
        query: str,
        relevant_indices: List[int],
        top_k: int = 10
    ) -> Dict[str, float]:
        """
        Evaluate a single query interactively

        Args:
            query: Query to evaluate
            relevant_indices: Ground-truth relevant document indices
            top_k: Number of results to consider

        Returns:
            Metrics for this query
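
            Example of the returned dict for top_k=10 (values illustrative):
                {"precision@10": 0.3, "recall@10": 0.75, "mrr": 1.0}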
        """
        # Get search results
        results = self.search(query, top_k=top_k, rerank=True)
        retrieved_indices = [r[0] for r in results]

        # Calculate metrics
        metrics = {}

        # Precision@k (denominator is top_k even if fewer results were returned)
        relevant_retrieved = sum(1 for idx in retrieved_indices if idx in relevant_indices)
        metrics[f'precision@{top_k}'] = relevant_retrieved / top_k

        # Recall@k
        if relevant_indices:
            metrics[f'recall@{top_k}'] = relevant_retrieved / len(relevant_indices)
        else:
            metrics[f'recall@{top_k}'] = 0.0

        # MRR: reciprocal rank of the first relevant result (0 if none retrieved)
        for i, idx in enumerate(retrieved_indices):
            if idx in relevant_indices:
                metrics['mrr'] = 1.0 / (i + 1)
                break
        else:
            metrics['mrr'] = 0.0

        # Print results for inspection
        print(f"\nQuery: {query}")
        print(f"Relevant docs: {relevant_indices}")
        print(f"\nTop {top_k} results:")
        for i, (idx, score, text) in enumerate(results):
            relevance = "✓" if idx in relevant_indices else "✗"
            print(f"{i + 1}. [{relevance}] (idx={idx}, score={score:.3f}) {text[:100]}...")

        print(f"\nMetrics: {metrics}")

        return metrics