import torch
import numpy as np
from typing import List, Dict, Optional, Tuple, Union
from tqdm import tqdm
import logging
from torch.utils.data import DataLoader
import json
import os
from collections import defaultdict
from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from .metrics import (
compute_retrieval_metrics,
compute_reranker_metrics,
evaluate_cross_lingual_retrieval,
MetricsTracker
)
logger = logging.getLogger(__name__)
class RetrievalEvaluator:
"""
Comprehensive evaluator for retrieval models
Notes:
- All parameters (model, tokenizer, device, metrics, k_values, etc.) should be passed from config-driven scripts.
- This class does not load config directly; pass config values from your main script for consistency.
"""
def __init__(
self,
model: Union[BGEM3Model, BGERerankerModel],
tokenizer,
device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        metrics: Optional[List[str]] = None,
        k_values: Optional[List[int]] = None
):
self.model = model
self.tokenizer = tokenizer
self.device = device
        # Avoid mutable default arguments: fall back to the standard metric set
        self.metrics = metrics if metrics is not None else ['recall', 'precision', 'map', 'mrr', 'ndcg']
        self.k_values = k_values if k_values is not None else [1, 5, 10, 20, 100]
# Move model to device
self.model.to(device)
self.model.eval()
# Metrics tracker
self.tracker = MetricsTracker()
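    # --- Usage sketch (illustrative, not part of the class) -----------------
    # A minimal example of constructing the evaluator from a config-driven
    # script, as the class docstring recommends. The loader helpers
    # (`load_model`, `load_tokenizer`) and the config path are hypothetical.
    #
    #   with open("configs/eval.json") as f:
    #       cfg = json.load(f)
    #   model, tokenizer = load_model(cfg["model_path"]), load_tokenizer(cfg["model_path"])
    #   evaluator = RetrievalEvaluator(
    #       model=model,
    #       tokenizer=tokenizer,
    #       device=cfg.get("device", "cuda"),
    #       metrics=cfg.get("metrics", ["recall", "mrr", "ndcg"]),
    #       k_values=cfg.get("k_values", [1, 5, 10]),
    #   )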
def evaluate(
self,
dataloader: DataLoader,
desc: str = "Evaluating",
return_embeddings: bool = False
) -> Dict[str, Union[float, list, np.ndarray]]:
"""
Evaluate model on dataset
Args:
dataloader: DataLoader with evaluation data
desc: Description for progress bar
return_embeddings: Whether to return computed embeddings
Returns:
Dictionary of evaluation metrics
"""
self.model.eval()
all_query_embeddings = []
all_passage_embeddings = []
all_labels = []
all_queries = []
all_passages = []
with torch.no_grad():
for batch in tqdm(dataloader, desc=desc):
# Move batch to device
batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in batch.items()}
# Get embeddings based on model type
if isinstance(self.model, BGEM3Model):
# Bi-encoder evaluation
query_embeddings, passage_embeddings = self._evaluate_biencoder_batch(batch)
all_query_embeddings.append(query_embeddings.cpu())
all_passage_embeddings.append(passage_embeddings.cpu())
else:
# Cross-encoder evaluation
scores = self._evaluate_crossencoder_batch(batch)
# For cross-encoder, we store scores instead of embeddings
all_query_embeddings.append(scores.cpu())
# Collect labels
if 'labels' in batch:
all_labels.append(batch['labels'].cpu())
# Collect text if available
if 'query' in batch:
all_queries.extend(batch['query'])
if 'passages' in batch:
all_passages.extend(batch['passages'])
# Concatenate results
if isinstance(self.model, BGEM3Model):
all_query_embeddings = torch.cat(all_query_embeddings, dim=0).numpy()
all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0).numpy()
all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None
# Compute retrieval metrics
if all_labels is not None:
metrics = compute_retrieval_metrics(
all_query_embeddings,
all_passage_embeddings,
all_labels,
k_values=self.k_values
)
else:
metrics = {}
else:
# Cross-encoder metrics
all_scores = torch.cat(all_query_embeddings, dim=0).numpy()
all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None
if all_labels is not None:
# Determine if we have groups (listwise) or pairs
if hasattr(dataloader.dataset, 'train_group_size'):
groups = [getattr(dataloader.dataset, 'train_group_size')] * len(dataloader.dataset) # type: ignore[arg-type]
else:
groups = None
metrics = compute_reranker_metrics(
all_scores,
all_labels,
groups=groups,
k_values=self.k_values
)
else:
metrics = {}
# Update tracker
self.tracker.update(metrics, step=0)
# Return embeddings if requested
if return_embeddings and isinstance(self.model, BGEM3Model):
metrics['query_embeddings'] = all_query_embeddings # type: ignore[assignment]
metrics['passage_embeddings'] = all_passage_embeddings # type: ignore[assignment]
metrics['queries'] = all_queries # type: ignore[assignment]
metrics['passages'] = all_passages # type: ignore[assignment]
return metrics # type: ignore[return-value]
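    # Usage sketch for evaluate(): `eval_loader` is a hypothetical DataLoader
    # whose batches carry 'query_input_ids'/'passage_input_ids' (bi-encoder)
    # or 'input_ids' (cross-encoder) plus 'labels'. The exact metric keys are
    # whatever compute_retrieval_metrics / compute_reranker_metrics return.
    #
    #   metrics = evaluator.evaluate(eval_loader, desc="Dev eval")
    #   for name, value in metrics.items():
    #       if isinstance(value, float):
    #           print(f"{name}: {value:.4f}")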
def _evaluate_biencoder_batch(self, batch: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
"""Evaluate a batch for bi-encoder model"""
# Get query embeddings
query_outputs = self.model(
input_ids=batch['query_input_ids'],
attention_mask=batch['query_attention_mask'],
return_dense=True
)
query_embeddings = query_outputs['dense']
# Get passage embeddings
# Handle multiple passages per query
if batch['passage_input_ids'].dim() == 3:
# Flatten batch dimension
batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
passage_input_ids = batch['passage_input_ids'].view(-1, seq_len)
passage_attention_mask = batch['passage_attention_mask'].view(-1, seq_len)
else:
passage_input_ids = batch['passage_input_ids']
passage_attention_mask = batch['passage_attention_mask']
passage_outputs = self.model(
input_ids=passage_input_ids,
attention_mask=passage_attention_mask,
return_dense=True
)
passage_embeddings = passage_outputs['dense']
return query_embeddings, passage_embeddings
def _evaluate_crossencoder_batch(self, batch: Dict) -> torch.Tensor:
"""Evaluate a batch for cross-encoder model"""
outputs = self.model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask']
)
return outputs['logits']
def evaluate_retrieval_pipeline(
self,
queries: List[str],
corpus: List[str],
relevant_docs: Dict[str, List[int]],
batch_size: int = 32
) -> Dict[str, float]:
"""
Evaluate full retrieval pipeline
Args:
queries: List of queries
corpus: List of all documents
relevant_docs: Dict mapping query to relevant doc indices
batch_size: Batch size for encoding
Returns:
Evaluation metrics
"""
logger.info(f"Evaluating retrieval pipeline with {len(queries)} queries over {len(corpus)} documents")
# Encode all documents
logger.info("Encoding corpus...")
corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")
# Encode queries and evaluate
all_metrics = defaultdict(list)
for i in tqdm(range(0, len(queries), batch_size), desc="Evaluating queries"):
batch_queries = queries[i:i + batch_size]
batch_relevant = [relevant_docs.get(q, []) for q in batch_queries]
# Encode queries
query_embeddings = self._encode_texts(batch_queries, batch_size, show_progress=False)
# Create labels matrix
labels = np.zeros((len(batch_queries), len(corpus)))
for j, relevant_indices in enumerate(batch_relevant):
for idx in relevant_indices:
if idx < len(corpus):
labels[j, idx] = 1
# Compute metrics
batch_metrics = compute_retrieval_metrics(
query_embeddings,
corpus_embeddings,
labels,
k_values=self.k_values
)
# Accumulate metrics
for metric, value in batch_metrics.items():
if metric != 'scores':
all_metrics[metric].append(value)
# Average metrics
final_metrics = {
metric: np.mean(values) for metric, values in all_metrics.items()
}
return final_metrics # type: ignore[return-value]
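    # Usage sketch for evaluate_retrieval_pipeline(); the data below is
    # illustrative only. Note that `relevant_docs` is keyed by the query
    # *text* and maps to indices into `corpus`.
    #
    #   queries = ["what is dense retrieval?"]
    #   corpus = ["Dense retrieval encodes queries and documents ...", "Unrelated text."]
    #   relevant = {"what is dense retrieval?": [0]}
    #   pipeline_metrics = evaluator.evaluate_retrieval_pipeline(
    #       queries, corpus, relevant, batch_size=16
    #   )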
def _encode_texts(
self,
texts: List[str],
batch_size: int,
desc: str = "Encoding",
show_progress: bool = True
) -> np.ndarray:
"""Encode texts to embeddings"""
if isinstance(self.model, BGEM3Model):
            # Use the model's own encode method; it handles batching internally,
            # so the desc/show_progress arguments are currently unused here
embeddings = self.model.encode(
texts,
batch_size=batch_size,
device=self.device
)
if isinstance(embeddings, torch.Tensor):
embeddings = embeddings.cpu().numpy()
else:
# For cross-encoder, this doesn't make sense
raise ValueError("Cannot encode texts with cross-encoder model")
return embeddings # type: ignore[return-value]
def evaluate_cross_lingual(
self,
queries: Dict[str, List[str]],
corpus: List[str],
relevant_docs: Dict[str, Dict[str, List[int]]],
batch_size: int = 32
) -> Dict[str, Dict[str, float]]:
"""
Evaluate cross-lingual retrieval
Args:
queries: Dict mapping language to queries
corpus: List of documents (assumed to be in source language)
relevant_docs: Nested dict of language -> query -> relevant indices
batch_size: Batch size
Returns:
Metrics per language
"""
# Encode corpus once
corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")
# Evaluate each language
results = {}
for lang, lang_queries in queries.items():
logger.info(f"Evaluating {lang} queries...")
# Encode queries
query_embeddings = self._encode_texts(lang_queries, batch_size, desc=f"Encoding {lang} queries")
# Create labels
labels = np.zeros((len(lang_queries), len(corpus)))
for i, query in enumerate(lang_queries):
if query in relevant_docs.get(lang, {}):
for idx in relevant_docs[lang][query]:
if idx < len(corpus):
labels[i, idx] = 1
# Compute metrics
metrics = compute_retrieval_metrics(
query_embeddings,
corpus_embeddings,
labels,
k_values=self.k_values
)
results[lang] = metrics
# Compute average across languages
if len(results) > 1:
avg_metrics = {}
for metric in results[list(results.keys())[0]]:
if metric != 'scores':
values = [results[lang][metric] for lang in results]
avg_metrics[metric] = np.mean(values)
results['average'] = avg_metrics
return results
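    # Usage sketch for evaluate_cross_lingual(); it shows the expected nesting
    # of `queries` and `relevant_docs` (language -> query text -> doc indices).
    # The strings are placeholders. With more than one language, an extra
    # 'average' entry is added to the result.
    #
    #   queries = {"en": ["solar power"], "zh": ["太阳能"]}
    #   relevant = {"en": {"solar power": [0]}, "zh": {"太阳能": [0]}}
    #   per_lang = evaluator.evaluate_cross_lingual(queries, corpus, relevant)
    #   print(per_lang.get("average", {}))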
def save_evaluation_results(
self,
metrics: Dict[str, float],
output_dir: str,
prefix: str = "eval"
):
"""Save evaluation results"""
os.makedirs(output_dir, exist_ok=True)
# Save metrics
metrics_file = os.path.join(output_dir, f"{prefix}_metrics.json")
with open(metrics_file, 'w') as f:
json.dump(metrics, f, indent=2)
# Save summary
summary_file = os.path.join(output_dir, f"{prefix}_summary.txt")
with open(summary_file, 'w') as f:
f.write(f"Evaluation Results - {prefix}\n")
f.write("=" * 50 + "\n\n")
for metric, value in sorted(metrics.items()):
if isinstance(value, float):
f.write(f"{metric:.<30} {value:.4f}\n")
logger.info(f"Saved evaluation results to {output_dir}")
class InteractiveEvaluator:
"""
Interactive evaluation for qualitative analysis
"""
def __init__(
self,
retriever: BGEM3Model,
reranker: Optional[BGERerankerModel],
tokenizer,
corpus: List[str],
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
):
self.retriever = retriever
self.reranker = reranker
self.tokenizer = tokenizer
self.corpus = corpus
self.device = device
# Pre-encode corpus
logger.info("Pre-encoding corpus...")
self.corpus_embeddings = self._encode_corpus()
logger.info(f"Encoded {len(self.corpus)} documents")
def _encode_corpus(self) -> np.ndarray:
"""Pre-encode all corpus documents"""
embeddings = self.retriever.encode(
self.corpus,
batch_size=32,
device=self.device
)
if isinstance(embeddings, torch.Tensor):
embeddings = embeddings.cpu().numpy()
return embeddings # type: ignore[return-value]
def search(
self,
query: str,
top_k: int = 10,
rerank: bool = True
) -> List[Tuple[int, float, str]]:
"""
Search for query in corpus
Args:
query: Search query
top_k: Number of results to return
rerank: Whether to use reranker
Returns:
List of (index, score, text) tuples
"""
# Encode query
query_embedding = self.retriever.encode(
[query],
device=self.device
)
if isinstance(query_embedding, torch.Tensor):
query_embedding = query_embedding.cpu().numpy()
# Compute similarities
similarities = np.dot(query_embedding, self.corpus_embeddings.T).squeeze() # type: ignore[arg-type]
# Get top-k
top_indices = np.argsort(similarities)[::-1][:top_k * 2 if rerank else top_k]
results = []
for idx in top_indices:
results.append((int(idx), float(similarities[idx]), self.corpus[idx]))
# Rerank if requested
if rerank and self.reranker:
reranked_indices = self.reranker.rerank(
query,
[self.corpus[idx] for idx in top_indices],
return_scores=False
)
# Reorder results
reranked_results = []
for i, rerank_idx in enumerate(reranked_indices[:top_k]):
original_idx = top_indices[rerank_idx] # type: ignore[index]
reranked_results.append((
int(original_idx),
                    float(i + 1),  # 1-based rerank position used as a stand-in score (lower = better)
self.corpus[original_idx]
))
results = reranked_results
return results[:top_k]
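    # Usage sketch for search(): with rerank=True, twice as many candidates
    # (top_k * 2) are retrieved first and the reranker, if present, reorders
    # them; the returned "score" is then the 1-based rerank position rather
    # than a similarity value.
    #
    #   hits = interactive.search("how does the reranker work?", top_k=5)
    #   for doc_id, score, text in hits:
    #       print(doc_id, round(score, 3), text[:80])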
def evaluate_query_interactively(
self,
query: str,
relevant_indices: List[int],
top_k: int = 10
) -> Dict[str, float]:
"""
Evaluate a single query interactively
Args:
query: Query to evaluate
relevant_indices: Ground truth relevant document indices
top_k: Number of results to consider
Returns:
Metrics for this query
"""
# Get search results
results = self.search(query, top_k=top_k, rerank=True)
retrieved_indices = [r[0] for r in results]
# Calculate metrics
metrics = {}
# Precision@k
relevant_retrieved = sum(1 for idx in retrieved_indices if idx in relevant_indices)
metrics[f'precision@{top_k}'] = relevant_retrieved / top_k
# Recall@k
if relevant_indices:
metrics[f'recall@{top_k}'] = relevant_retrieved / len(relevant_indices)
else:
metrics[f'recall@{top_k}'] = 0.0
# MRR
for i, idx in enumerate(retrieved_indices):
if idx in relevant_indices:
metrics['mrr'] = 1.0 / (i + 1)
break
else:
metrics['mrr'] = 0.0
# Print results for inspection
print(f"\nQuery: {query}")
print(f"Relevant docs: {relevant_indices}")
print(f"\nTop {top_k} results:")
for i, (idx, score, text) in enumerate(results):
relevance = "" if idx in relevant_indices else ""
print(f"{i + 1}. [{relevance}] (idx={idx}, score={score:.3f}) {text[:100]}...")
print(f"\nMetrics: {metrics}")
return metrics
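# --- Interactive evaluation sketch (illustrative only) ----------------------
# A hypothetical session combining the two evaluators; model and tokenizer
# loading is project-specific and omitted here.
#
#   interactive = InteractiveEvaluator(retriever, reranker, tokenizer, corpus)
#   per_query = interactive.evaluate_query_interactively(
#       "what is dense retrieval?", relevant_indices=[0, 3], top_k=10
#   )
#   # per_query holds precision@10, recall@10 and mrr for this single query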