import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from sklearn.metrics import average_precision_score, ndcg_score

logger = logging.getLogger(__name__)


def compute_retrieval_metrics(
    query_embeddings: np.ndarray,
    passage_embeddings: np.ndarray,
    labels: np.ndarray,
    k_values: List[int] = [1, 5, 10, 20, 100],
    return_all_scores: bool = False,
) -> Dict[str, float]:
    """
    Compute standard retrieval metrics (recall, precision, MAP, MRR, NDCG at the given k values).

    Args:
        query_embeddings: Query embeddings [num_queries, embedding_dim]
        passage_embeddings: Passage embeddings [num_passages, embedding_dim]
        labels: Binary relevance labels [num_queries, num_passages], or a
            1-D array of the relevant passage index for each query
        k_values: List of k values for metrics@k
        return_all_scores: Whether to also return the full similarity matrix

    Returns:
        Dictionary of metrics
    """
    metrics = {}

    # Compute similarity scores (dot product; embeddings are used as-is)
    scores = np.matmul(query_embeddings, passage_embeddings.T)

    # Get number of queries
    num_queries = query_embeddings.shape[0]

    # Convert labels to a binary matrix if a 1-D array of indices was given
    if labels.ndim == 1:
        binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
        for i, idx in enumerate(labels):
            if idx < passage_embeddings.shape[0]:
                binary_labels[i, idx] = 1
        labels = binary_labels

    # Compute metrics for each k
    for k in k_values:
        metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
        metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
        metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
        metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
        metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)

    # Overall metrics
    metrics['map'] = compute_mean_average_precision(scores, labels)
    metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)

    if return_all_scores:
        metrics['scores'] = scores

    return metrics
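

# Illustrative usage (a minimal sketch on synthetic data; the shapes, seed and
# values are made up for demonstration only):
#
#   rng = np.random.default_rng(0)
#   queries = rng.normal(size=(4, 32))
#   passages = rng.normal(size=(10, 32))
#   relevant = np.array([0, 3, 5, 7])  # index of the relevant passage per query
#   metrics = compute_retrieval_metrics(queries, passages, relevant, k_values=[1, 5])
#   # e.g. metrics['recall@5'], metrics['mrr']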


def compute_retrieval_metrics_comprehensive(
    query_embeddings: np.ndarray,
    passage_embeddings: np.ndarray,
    labels: np.ndarray,
    k_values: List[int] = [1, 5, 10, 20, 100],
    return_all_scores: bool = False,
    compute_additional_metrics: bool = True,
) -> Dict[str, float]:
    """
    Compute comprehensive retrieval metrics with additional statistics.

    Args:
        query_embeddings: Query embeddings [num_queries, embedding_dim]
        passage_embeddings: Passage embeddings [num_passages, embedding_dim]
        labels: Binary relevance labels [num_queries, num_passages], or a
            1-D array of the relevant passage index for each query
        k_values: List of k values for metrics@k
        return_all_scores: Whether to also return the full similarity matrix
        compute_additional_metrics: Whether to compute additional metrics
            such as success rate and coverage

    Returns:
        Dictionary of comprehensive metrics
    """
    metrics = {}

    # Validate inputs
    if query_embeddings.shape[0] == 0 or passage_embeddings.shape[0] == 0:
        raise ValueError("Empty embeddings provided")

    # Normalize embeddings so the dot product below is cosine similarity
    query_embeddings = query_embeddings / (np.linalg.norm(query_embeddings, axis=1, keepdims=True) + 1e-8)
    passage_embeddings = passage_embeddings / (np.linalg.norm(passage_embeddings, axis=1, keepdims=True) + 1e-8)

    # Compute similarity scores
    scores = np.matmul(query_embeddings, passage_embeddings.T)

    # Get number of queries
    num_queries = query_embeddings.shape[0]

    # Convert labels to a binary matrix if a 1-D array of indices was given
    if labels.ndim == 1:
        binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
        for i, idx in enumerate(labels):
            if idx < passage_embeddings.shape[0]:
                binary_labels[i, idx] = 1
        labels = binary_labels

    # Compute metrics for each k
    for k in k_values:
        metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
        metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
        metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
        metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
        metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)

        if compute_additional_metrics:
            # Success rate: fraction of queries with at least one relevant doc in the top-k
            success_count = 0
            for i in range(num_queries):
                top_k_indices = np.argsort(scores[i])[::-1][:k]
                if np.any(labels[i][top_k_indices] > 0):
                    success_count += 1
            metrics[f'success@{k}'] = success_count / num_queries if num_queries > 0 else 0.0

    # Overall metrics
    metrics['map'] = compute_mean_average_precision(scores, labels)
    metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)

    if compute_additional_metrics:
        # Average number of relevant documents per query
        metrics['avg_relevant_per_query'] = float(np.mean(np.sum(labels > 0, axis=1)))

        # Coverage: fraction of relevant documents that appear in any top-k list
        max_k = max(k_values)
        covered_relevant = set()
        for i in range(num_queries):
            top_k_indices = np.argsort(scores[i])[::-1][:max_k]
            relevant_indices = np.where(labels[i] > 0)[0]
            for idx in relevant_indices:
                if idx in top_k_indices:
                    covered_relevant.add((i, idx))

        total_relevant = np.sum(labels > 0)
        metrics['coverage'] = len(covered_relevant) / total_relevant if total_relevant > 0 else 0.0

    if return_all_scores:
        metrics['scores'] = scores

    return metrics
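

# Illustrative usage (a sketch; same synthetic-data caveats as the example above).
# Unlike compute_retrieval_metrics, this variant L2-normalizes both embedding
# sets, so the scores are cosine similarities:
#
#   metrics = compute_retrieval_metrics_comprehensive(
#       queries, passages, relevant, k_values=[1, 5]
#   )
#   # adds e.g. metrics['success@5'], metrics['coverage'], metrics['avg_relevant_per_query']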


def compute_recall_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Recall@K.

    Args:
        scores: Similarity scores [num_queries, num_passages]
        labels: Binary relevance labels
        k: Top-k to consider

    Returns:
        Average recall@k
    """
    num_queries = scores.shape[0]
    recall_scores = []

    for i in range(num_queries):
        # Get top-k indices
        top_k_indices = np.argsort(scores[i])[::-1][:k]

        # Count relevant documents; skip queries with no relevant documents
        num_relevant = labels[i].sum()
        if num_relevant == 0:
            continue

        # Count relevant documents in the top-k
        num_relevant_in_topk = labels[i][top_k_indices].sum()

        recall_scores.append(num_relevant_in_topk / num_relevant)

    return float(np.mean(recall_scores)) if recall_scores else 0.0


def compute_precision_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Precision@K.
    """
    num_queries = scores.shape[0]
    precision_scores = []

    for i in range(num_queries):
        # Get top-k indices
        top_k_indices = np.argsort(scores[i])[::-1][:k]

        # Count relevant documents in the top-k
        num_relevant_in_topk = labels[i][top_k_indices].sum()

        precision_scores.append(num_relevant_in_topk / k)

    return float(np.mean(precision_scores)) if precision_scores else 0.0


def compute_map_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Mean Average Precision@K.
    """
    num_queries = scores.shape[0]
    ap_scores = []

    for i in range(num_queries):
        # Get top-k indices
        top_k_indices = np.argsort(scores[i])[::-1][:k]

        # Compute average precision over the top-k ranking
        num_relevant = 0
        sum_precision = 0.0

        for j, idx in enumerate(top_k_indices):
            if labels[i][idx] == 1:
                num_relevant += 1
                sum_precision += num_relevant / (j + 1)

        if labels[i].sum() > 0:
            ap_scores.append(sum_precision / min(labels[i].sum(), k))

    return float(np.mean(ap_scores)) if ap_scores else 0.0


def compute_mrr_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Mean Reciprocal Rank@K with proper error handling.

    Args:
        scores: Similarity scores [num_queries, num_passages]
        labels: Binary relevance labels
        k: Top-k to consider

    Returns:
        MRR@k score
    """
    if scores.shape != labels.shape:
        raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")

    num_queries = scores.shape[0]
    rr_scores = []

    for i in range(num_queries):
        # Skip queries with no relevant documents
        if labels[i].sum() == 0:
            continue

        # Get top-k indices
        top_k_indices = np.argsort(scores[i])[::-1][:k]

        # Find the first relevant document
        for j, idx in enumerate(top_k_indices):
            if labels[i][idx] > 0:  # supports graded relevance
                rr_scores.append(1.0 / (j + 1))
                break
        else:
            # No relevant document in the top-k
            rr_scores.append(0.0)

    return float(np.mean(rr_scores)) if rr_scores else 0.0


def compute_ndcg_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Normalized Discounted Cumulative Gain@K with robust error handling.

    Args:
        scores: Similarity scores [num_queries, num_passages]
        labels: Binary or graded relevance labels
        k: Top-k to consider

    Returns:
        NDCG@k score
    """
    if scores.shape != labels.shape:
        raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")

    num_queries = scores.shape[0]
    ndcg_scores = []

    for i in range(num_queries):
        # Skip queries with no relevant documents
        if labels[i].sum() == 0:
            continue

        # Get scores and labels for this query
        query_scores = scores[i]
        query_labels = labels[i]

        # Get top-k indices by score
        top_k_indices = np.argsort(query_scores)[::-1][:k]

        # Compute DCG@k: rank j (0-based) is discounted by log2(j + 2),
        # i.e. log2(position + 1) for 1-based positions
        dcg = 0.0
        for j, idx in enumerate(top_k_indices):
            if query_labels[idx] > 0:
                dcg += query_labels[idx] / np.log2(j + 2)

        # Compute ideal DCG@k from the top-k labels sorted in descending order
        ideal_labels = np.sort(query_labels)[::-1][:k]
        idcg = 0.0
        for j, label in enumerate(ideal_labels):
            if label > 0:
                idcg += label / np.log2(j + 2)

        # Compute NDCG
        if idcg > 0:
            ndcg_scores.append(dcg / idcg)
        else:
            ndcg_scores.append(0.0)

    return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
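

# Worked example (binary relevance, k = 3). Suppose a query has two relevant
# passages and the ranking places them at positions 1 and 3:
#
#   DCG@3  = 1 / log2(2) + 1 / log2(4) = 1.0 + 0.5   = 1.5
#   IDCG@3 = 1 / log2(2) + 1 / log2(3) = 1.0 + 0.631 = 1.631
#   NDCG@3 = 1.5 / 1.631 ≈ 0.92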


def compute_mean_average_precision(scores: np.ndarray, labels: np.ndarray) -> float:
    """
    Compute Mean Average Precision (MAP) over all positions.
    """
    num_queries = scores.shape[0]
    ap_scores = []

    for i in range(num_queries):
        if labels[i].sum() > 0:
            ap = average_precision_score(labels[i], scores[i])
            ap_scores.append(ap)

    return float(np.mean(ap_scores)) if ap_scores else 0.0


def compute_mean_reciprocal_rank(scores: np.ndarray, labels: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR) over all positions.
    """
    num_queries = scores.shape[0]
    rr_scores = []

    for i in range(num_queries):
        # Get the full ranking for this query
        ranked_indices = np.argsort(scores[i])[::-1]

        # Find the first relevant document
        for j, idx in enumerate(ranked_indices):
            if labels[i][idx] == 1:
                rr_scores.append(1.0 / (j + 1))
                break
        else:
            rr_scores.append(0.0)

    return float(np.mean(rr_scores)) if rr_scores else 0.0


def compute_reranker_metrics(
    scores: Union[np.ndarray, torch.Tensor],
    labels: Union[np.ndarray, torch.Tensor],
    groups: Optional[List[int]] = None,
    k_values: List[int] = [1, 5, 10],
) -> Dict[str, float]:
    """
    Compute metrics specifically for reranker evaluation.

    Args:
        scores: Relevance scores from the reranker
        labels: Binary relevance labels
        groups: Group sizes for each query (for listwise evaluation)
        k_values: List of k values

    Returns:
        Dictionary of reranker-specific metrics
    """
    # Convert to numpy if needed
    if isinstance(scores, torch.Tensor):
        scores = scores.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    metrics = {}

    if groups is None:
        # Assume pairwise (query, passage) evaluation: binary classification metrics
        predictions = (scores > 0.5).astype(int)  # type: ignore[attr-defined]

        # Accuracy
        metrics['accuracy'] = float(np.mean(predictions == labels))

        # Compute AUC if we have both positive and negative examples
        if len(np.unique(labels)) > 1:
            from sklearn.metrics import roc_auc_score
            metrics['auc'] = float(roc_auc_score(labels, scores))

    else:
        # Listwise evaluation: scores/labels are concatenated per-query groups
        start_idx = 0
        mrr_scores = []
        ndcg_scores = defaultdict(list)

        for group_size in groups:
            # Get scores and labels for this group
            group_scores = scores[start_idx:start_idx + group_size]
            group_labels = labels[start_idx:start_idx + group_size]

            # Sort by scores (descending)
            sorted_indices = np.argsort(group_scores)[::-1]

            # MRR
            for i, idx in enumerate(sorted_indices):
                if group_labels[idx] == 1:
                    mrr_scores.append(1.0 / (i + 1))
                    break
            else:
                mrr_scores.append(0.0)

            # NDCG@k
            for k in k_values:
                if group_size >= k:
                    try:
                        ndcg = ndcg_score(
                            group_labels.reshape(1, -1),
                            group_scores.reshape(1, -1),
                            k=k,
                        )
                        ndcg_scores[k].append(ndcg)
                    except Exception:
                        # Skip degenerate groups (e.g. all labels identical)
                        pass

            start_idx += group_size

        # Aggregate metrics
        metrics['mrr'] = float(np.mean(mrr_scores)) if mrr_scores else 0.0

        for k in k_values:
            if ndcg_scores[k]:
                metrics[f'ndcg@{k}'] = float(np.mean(ndcg_scores[k]))

    return metrics
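

# Illustrative listwise call (a sketch; the numbers are made up). With
# groups=[3, 2], scores[0:3]/labels[0:3] belong to the first query and
# scores[3:5]/labels[3:5] to the second:
#
#   scores = np.array([0.9, 0.2, 0.7, 0.1, 0.8])
#   labels = np.array([0,   1,   0,   0,   1  ])
#   metrics = compute_reranker_metrics(scores, labels, groups=[3, 2], k_values=[1])
#   # first query: the relevant doc is ranked third -> RR = 1/3
#   # second query: the relevant doc is ranked first -> RR = 1
#   # metrics['mrr'] == (1/3 + 1) / 2 ≈ 0.667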


def evaluate_cross_lingual_retrieval(
    query_embeddings: Dict[str, np.ndarray],
    passage_embeddings: np.ndarray,
    labels: Dict[str, np.ndarray],
    languages: List[str] = ['en', 'zh'],
) -> Dict[str, Dict[str, float]]:
    """
    Evaluate cross-lingual retrieval performance.

    Args:
        query_embeddings: Dict mapping language to query embeddings
        passage_embeddings: Passage embeddings
        labels: Dict mapping language to relevance labels
        languages: List of languages to evaluate

    Returns:
        Nested dict of metrics per language
    """
    results = {}

    for lang in languages:
        if lang in query_embeddings and lang in labels:
            logger.info(f"Evaluating {lang} queries")

            results[lang] = compute_retrieval_metrics(
                query_embeddings[lang],
                passage_embeddings,
                labels[lang],
            )

    # Compute cross-lingual averages if we have results for every language
    if len(languages) > 1 and all(lang in results for lang in languages):
        # Average performance across languages
        avg_metrics = {}

        for metric in results[languages[0]].keys():
            if metric != 'scores':
                values = [results[lang][metric] for lang in languages]
                avg_metrics[metric] = float(np.mean(values))

        results['average'] = avg_metrics

    return results
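

# Illustrative call shape (a sketch; the language keys and variable names are
# hypothetical placeholders):
#
#   results = evaluate_cross_lingual_retrieval(
#       query_embeddings={'en': en_queries, 'zh': zh_queries},   # [n_q, dim] each
#       passage_embeddings=passages,                             # [n_p, dim]
#       labels={'en': en_relevant_idx, 'zh': zh_relevant_idx},   # per-language labels
#   )
#   # results['en']['recall@10'], results['zh']['mrr'], results['average']['map']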


class MetricsTracker:
    """
    Track and aggregate metrics during training.

    Note: "best" tracking assumes higher values are better, which holds for the
    retrieval metrics in this module but not for losses.
    """

    def __init__(self):
        self.metrics = defaultdict(list)
        self.best_metrics = {}
        self.best_steps = {}

    def update(self, metrics: Dict[str, float], step: int):
        """Update metrics for the current step"""
        for key, value in metrics.items():
            self.metrics[key].append((step, value))

            # Track best metrics (higher is better)
            if key not in self.best_metrics or value > self.best_metrics[key]:
                self.best_metrics[key] = value
                self.best_steps[key] = step

    def get_current(self, metric: str) -> Optional[float]:
        """Get the most recent value for a metric"""
        if metric in self.metrics and self.metrics[metric]:
            return self.metrics[metric][-1][1]
        return None

    def get_best(self, metric: str) -> Optional[Tuple[int, float]]:
        """Get the (step, value) pair for the best value of a metric"""
        if metric in self.best_metrics:
            return self.best_steps[metric], self.best_metrics[metric]
        return None

    def get_history(self, metric: str) -> List[Tuple[int, float]]:
        """Get the full history for a metric"""
        return self.metrics.get(metric, [])

    def summary(self) -> Dict[str, Dict[str, float]]:
        """Get a summary of all tracked metrics"""
        summary = {
            'current': {},
            'best': {},
            'best_steps': {}
        }

        for metric in self.metrics:
            summary['current'][metric] = self.get_current(metric)
            summary['best'][metric] = self.best_metrics.get(metric)
            summary['best_steps'][metric] = self.best_steps.get(metric)

        return summary
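

if __name__ == "__main__":
    # Minimal smoke test on synthetic data (a sketch for illustration only;
    # the shapes, seed, and k values are arbitrary).
    rng = np.random.default_rng(0)
    queries = rng.normal(size=(8, 64))
    passages = rng.normal(size=(50, 64))
    relevant = rng.integers(0, 50, size=8)  # one relevant passage index per query

    metrics = compute_retrieval_metrics(queries, passages, relevant, k_values=[1, 5, 10])

    tracker = MetricsTracker()
    tracker.update(metrics, step=0)

    for name, value in sorted(metrics.items()):
        print(f"{name}: {value:.4f}")
    print("best mrr:", tracker.get_best('mrr'))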