import numpy as np
from typing import List, Dict, Tuple, Optional, Union
import torch
from sklearn.metrics import average_precision_score, ndcg_score
from collections import defaultdict
import logging
logger = logging.getLogger(__name__)
def compute_retrieval_metrics(
query_embeddings: np.ndarray,
passage_embeddings: np.ndarray,
labels: np.ndarray,
k_values: List[int] = [1, 5, 10, 20, 100],
return_all_scores: bool = False
) -> Dict[str, float]:
"""
    Compute standard retrieval metrics (recall/precision/MAP/MRR/NDCG) at several cutoffs
Args:
query_embeddings: Query embeddings [num_queries, embedding_dim]
passage_embeddings: Passage embeddings [num_passages, embedding_dim]
labels: Binary relevance labels [num_queries, num_passages] or list of relevant indices
k_values: List of k values for metrics@k
return_all_scores: Whether to return all similarity scores
Returns:
Dictionary of metrics
"""
metrics = {}
    # Compute similarity scores (raw dot product; pass L2-normalized embeddings for cosine similarity)
    scores = np.matmul(query_embeddings, passage_embeddings.T)
# Get number of queries
num_queries = query_embeddings.shape[0]
# Convert labels to binary matrix if needed
if labels.ndim == 1:
# Assume labels contains relevant passage indices
binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
for i, idx in enumerate(labels):
if idx < passage_embeddings.shape[0]:
binary_labels[i, idx] = 1
labels = binary_labels
# Compute metrics for each k
for k in k_values:
metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)
# Overall metrics
metrics['map'] = compute_mean_average_precision(scores, labels)
metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)
if return_all_scores:
metrics['scores'] = scores
return metrics
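# Usage sketch (illustrative only; the shapes, seed, and relevance pattern below are
# assumptions, not part of this module's API): score 4 toy queries against 10 toy
# passages with one relevant passage each, then read recall@k / MRR from the result.
def _example_compute_retrieval_metrics() -> Dict[str, float]:
    rng = np.random.default_rng(0)
    queries = rng.standard_normal((4, 32)).astype(np.float32)
    passages = rng.standard_normal((10, 32)).astype(np.float32)
    relevance = np.zeros((4, 10))
    relevance[np.arange(4), [0, 3, 5, 7]] = 1  # one relevant passage per query
    return compute_retrieval_metrics(queries, passages, relevance, k_values=[1, 5, 10])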
def compute_retrieval_metrics_comprehensive(
query_embeddings: np.ndarray,
passage_embeddings: np.ndarray,
labels: np.ndarray,
k_values: List[int] = [1, 5, 10, 20, 100],
return_all_scores: bool = False,
compute_additional_metrics: bool = True
) -> Dict[str, float]:
"""
Compute comprehensive retrieval metrics with additional statistics
Args:
query_embeddings: Query embeddings [num_queries, embedding_dim]
passage_embeddings: Passage embeddings [num_passages, embedding_dim]
labels: Binary relevance labels [num_queries, num_passages] or list of relevant indices
k_values: List of k values for metrics@k
return_all_scores: Whether to return all similarity scores
compute_additional_metrics: Whether to compute additional metrics like success rate
Returns:
Dictionary of comprehensive metrics
"""
metrics = {}
# Validate inputs
if query_embeddings.shape[0] == 0 or passage_embeddings.shape[0] == 0:
raise ValueError("Empty embeddings provided")
# Normalize embeddings for cosine similarity
query_embeddings = query_embeddings / (np.linalg.norm(query_embeddings, axis=1, keepdims=True) + 1e-8)
passage_embeddings = passage_embeddings / (np.linalg.norm(passage_embeddings, axis=1, keepdims=True) + 1e-8)
# Compute similarity scores
scores = np.matmul(query_embeddings, passage_embeddings.T)
# Get number of queries
num_queries = query_embeddings.shape[0]
# Convert labels to binary matrix if needed
if labels.ndim == 1:
# Assume labels contains relevant passage indices
binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
for i, idx in enumerate(labels):
if idx < passage_embeddings.shape[0]:
binary_labels[i, idx] = 1
labels = binary_labels
# Compute metrics for each k
for k in k_values:
metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)
if compute_additional_metrics:
# Success rate (queries with at least one relevant doc in top-k)
success_count = 0
for i in range(num_queries):
top_k_indices = np.argsort(scores[i])[::-1][:k]
if np.any(labels[i][top_k_indices] > 0):
success_count += 1
metrics[f'success@{k}'] = success_count / num_queries if num_queries > 0 else 0.0
# Overall metrics
metrics['map'] = compute_mean_average_precision(scores, labels)
metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)
if compute_additional_metrics:
# Average number of relevant documents per query
metrics['avg_relevant_per_query'] = np.mean(np.sum(labels > 0, axis=1))
# Coverage: fraction of relevant documents that appear in any top-k list
max_k = max(k_values)
covered_relevant = set()
        for i in range(num_queries):
            top_k_set = set(np.argsort(scores[i])[::-1][:max_k].tolist())
            relevant_indices = np.where(labels[i] > 0)[0]
            for idx in relevant_indices:
                if idx in top_k_set:
                    covered_relevant.add((i, idx))
total_relevant = np.sum(labels > 0)
metrics['coverage'] = len(covered_relevant) / total_relevant if total_relevant > 0 else 0.0
if return_all_scores:
metrics['scores'] = scores
return metrics
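# Usage sketch (illustrative only; the toy data below is an assumption): the
# comprehensive variant additionally reports success@k, avg_relevant_per_query and
# coverage, and L2-normalizes the embeddings before scoring.
def _example_comprehensive_metrics() -> Dict[str, float]:
    rng = np.random.default_rng(1)
    queries = rng.standard_normal((3, 16))
    passages = rng.standard_normal((8, 16))
    relevance = np.zeros((3, 8))
    relevance[0, [0, 2]] = 1  # query 0 has two relevant passages
    relevance[1, 5] = 1
    relevance[2, 7] = 1
    return compute_retrieval_metrics_comprehensive(queries, passages, relevance, k_values=[1, 5])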
def compute_recall_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Recall@K
Args:
scores: Similarity scores [num_queries, num_passages]
labels: Binary relevance labels
k: Top-k to consider
Returns:
Average recall@k
"""
num_queries = scores.shape[0]
recall_scores = []
for i in range(num_queries):
# Get top-k indices
top_k_indices = np.argsort(scores[i])[::-1][:k]
# Count relevant documents
num_relevant = labels[i].sum()
if num_relevant == 0:
continue
# Count relevant in top-k
num_relevant_in_topk = labels[i][top_k_indices].sum()
recall_scores.append(num_relevant_in_topk / num_relevant)
return float(np.mean(recall_scores)) if recall_scores else 0.0
def compute_precision_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Precision@K
"""
num_queries = scores.shape[0]
precision_scores = []
for i in range(num_queries):
# Get top-k indices
top_k_indices = np.argsort(scores[i])[::-1][:k]
# Count relevant in top-k
num_relevant_in_topk = labels[i][top_k_indices].sum()
precision_scores.append(num_relevant_in_topk / k)
    return float(np.mean(precision_scores)) if precision_scores else 0.0
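# Worked example (hand-checked, illustrative only): one query, three passages,
# scores [0.9, 0.1, 0.8], relevance [1, 0, 1]. The top-1 passage is relevant, so
# recall@1 = 1/2 and precision@1 = 1/1; at k=2 both relevant passages are retrieved,
# so recall@2 = 1.0 and precision@2 = 1.0.
def _example_recall_precision() -> Tuple[float, float]:
    scores = np.array([[0.9, 0.1, 0.8]])
    labels = np.array([[1, 0, 1]])
    return compute_recall_at_k(scores, labels, k=1), compute_precision_at_k(scores, labels, k=1)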
def compute_map_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Mean Average Precision@K
"""
num_queries = scores.shape[0]
ap_scores = []
for i in range(num_queries):
# Get top-k indices
top_k_indices = np.argsort(scores[i])[::-1][:k]
# Compute average precision
num_relevant = 0
sum_precision = 0
        for j, idx in enumerate(top_k_indices):
            if labels[i][idx] > 0:  # treat any positive label as relevant
                num_relevant += 1
                sum_precision += num_relevant / (j + 1)
        total_relevant = int((labels[i] > 0).sum())
        if total_relevant > 0:
            ap_scores.append(sum_precision / min(total_relevant, k))
return float(np.mean(ap_scores)) if ap_scores else 0.0
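# Worked example (hand-checked, illustrative only): the ranking [rel, non-rel, rel, non-rel]
# gives precision 1/1 at the first hit and 2/3 at the second, so
# AP@4 = (1.0 + 2/3) / 2 ≈ 0.833.
def _example_map_at_k() -> float:
    scores = np.array([[0.9, 0.8, 0.7, 0.6]])
    labels = np.array([[1, 0, 1, 0]])
    return compute_map_at_k(scores, labels, k=4)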
def compute_mrr_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Mean Reciprocal Rank@K with proper error handling
Args:
scores: Similarity scores [num_queries, num_passages]
labels: Binary relevance labels
k: Top-k to consider
Returns:
MRR@k score
"""
if scores.shape != labels.shape:
raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")
num_queries = scores.shape[0]
rr_scores = []
for i in range(num_queries):
# Skip queries with no relevant documents
if labels[i].sum() == 0:
continue
# Get top-k indices
top_k_indices = np.argsort(scores[i])[::-1][:k]
# Find first relevant document
for j, idx in enumerate(top_k_indices):
if labels[i][idx] > 0: # Support graded relevance
rr_scores.append(1.0 / (j + 1))
break
else:
# No relevant document in top-k
rr_scores.append(0.0)
return float(np.mean(rr_scores)) if rr_scores else 0.0
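# Worked example (hand-checked, illustrative only): the single relevant passage is
# ranked third by score, so MRR@5 = 1/3; with k=2 it falls outside the cutoff and
# the query contributes 0.
def _example_mrr_at_k() -> Tuple[float, float]:
    scores = np.array([[0.2, 0.9, 0.5]])
    labels = np.array([[1, 0, 0]])
    return compute_mrr_at_k(scores, labels, k=5), compute_mrr_at_k(scores, labels, k=2)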
def compute_ndcg_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
"""
Compute Normalized Discounted Cumulative Gain@K with robust error handling
Args:
scores: Similarity scores [num_queries, num_passages]
labels: Binary or graded relevance labels
k: Top-k to consider
Returns:
NDCG@k score
"""
if scores.shape != labels.shape:
raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")
num_queries = scores.shape[0]
ndcg_scores = []
for i in range(num_queries):
# Skip queries with no relevant documents
if labels[i].sum() == 0:
continue
# Get scores and labels for this query
query_scores = scores[i]
query_labels = labels[i]
# Get top-k indices by score
top_k_indices = np.argsort(query_scores)[::-1][:k]
# Compute DCG@k
dcg = 0.0
for j, idx in enumerate(top_k_indices):
if query_labels[idx] > 0:
                # Discount by log2(rank + 1): rank = j + 1, so the denominator is log2(j + 2)
dcg += query_labels[idx] / np.log2(j + 2)
# Compute ideal DCG@k
ideal_labels = np.sort(query_labels)[::-1][:k]
idcg = 0.0
for j, label in enumerate(ideal_labels):
if label > 0:
idcg += label / np.log2(j + 2)
# Compute NDCG
if idcg > 0:
ndcg_scores.append(dcg / idcg)
else:
ndcg_scores.append(0.0)
return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
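# Worked example (hand-checked, illustrative only) with graded labels [3, 0, 1] and
# scores [0.5, 0.9, 0.8]: the induced ranking is (grade 0, grade 1, grade 3), so
# DCG@3 = 1/log2(3) + 3/log2(4) ≈ 2.131, IDCG@3 = 3/log2(2) + 1/log2(3) ≈ 3.631,
# and NDCG@3 ≈ 0.587.
def _example_ndcg_at_k() -> float:
    scores = np.array([[0.5, 0.9, 0.8]])
    labels = np.array([[3, 0, 1]])
    return compute_ndcg_at_k(scores, labels, k=3)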
def compute_mean_average_precision(scores: np.ndarray, labels: np.ndarray) -> float:
"""
Compute Mean Average Precision (MAP) over all positions
"""
num_queries = scores.shape[0]
ap_scores = []
for i in range(num_queries):
if labels[i].sum() > 0:
ap = average_precision_score(labels[i], scores[i])
ap_scores.append(ap)
return float(np.mean(ap_scores)) if ap_scores else 0.0
def compute_mean_reciprocal_rank(scores: np.ndarray, labels: np.ndarray) -> float:
"""
Compute Mean Reciprocal Rank (MRR) over all positions
"""
num_queries = scores.shape[0]
rr_scores = []
    for i in range(num_queries):
        # Skip queries with no relevant documents (consistent with compute_mrr_at_k)
        if labels[i].sum() == 0:
            continue
        # Get ranking
        ranked_indices = np.argsort(scores[i])[::-1]
        # Find first relevant document
        for j, idx in enumerate(ranked_indices):
            if labels[i][idx] > 0:
                rr_scores.append(1.0 / (j + 1))
                break
    return float(np.mean(rr_scores)) if rr_scores else 0.0
def compute_reranker_metrics(
scores: Union[np.ndarray, torch.Tensor],
labels: Union[np.ndarray, torch.Tensor],
groups: Optional[List[int]] = None,
k_values: List[int] = [1, 5, 10]
) -> Dict[str, float]:
"""
Compute metrics specifically for reranker evaluation
Args:
scores: Relevance scores from reranker
labels: Binary relevance labels
groups: Group sizes for each query (for listwise evaluation)
k_values: List of k values
Returns:
Dictionary of reranker-specific metrics
"""
# Convert to numpy if needed
if isinstance(scores, torch.Tensor):
scores = scores.cpu().numpy()
if isinstance(labels, torch.Tensor):
labels = labels.cpu().numpy()
metrics = {}
if groups is None:
        # Pointwise / pairwise evaluation: treat scores as probabilities in [0, 1]
        # and threshold at 0.5 for binary classification metrics
        predictions = (scores > 0.5).astype(int)
        # Accuracy
        metrics['accuracy'] = float(np.mean(predictions == labels))
# Compute AUC if we have both positive and negative examples
if len(np.unique(labels)) > 1:
from sklearn.metrics import roc_auc_score
metrics['auc'] = roc_auc_score(labels, scores)
else:
# Listwise evaluation
start_idx = 0
mrr_scores = []
ndcg_scores = defaultdict(list)
for group_size in groups:
# Get scores and labels for this group
group_scores = scores[start_idx:start_idx + group_size]
group_labels = labels[start_idx:start_idx + group_size]
# Sort by scores
sorted_indices = np.argsort(group_scores)[::-1]
# MRR
for i, idx in enumerate(sorted_indices):
if group_labels[idx] == 1:
mrr_scores.append(1.0 / (i + 1))
break
else:
mrr_scores.append(0.0)
# NDCG@k
for k in k_values:
if group_size >= k:
try:
ndcg = ndcg_score(
group_labels.reshape(1, -1),
group_scores.reshape(1, -1),
k=k
)
ndcg_scores[k].append(ndcg)
                    except ValueError as exc:
                        logger.warning(f"Skipping NDCG@{k} for group of size {group_size}: {exc}")
start_idx += group_size
# Aggregate metrics
        metrics['mrr'] = float(np.mean(mrr_scores)) if mrr_scores else 0.0
        for k in k_values:
            if ndcg_scores[k]:
                metrics[f'ndcg@{k}'] = float(np.mean(ndcg_scores[k]))
return metrics
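# Usage sketch (illustrative only; the scores and group sizes below are assumptions):
# without `groups` the scores are treated as pointwise probabilities (accuracy / AUC);
# with `groups` they are evaluated listwise per query (MRR, NDCG@k).
def _example_reranker_metrics() -> Dict[str, float]:
    scores = np.array([0.9, 0.2, 0.4, 0.8, 0.1, 0.3, 0.7])
    labels = np.array([1, 0, 0, 1, 0, 0, 1])
    # Two queries: the first has 4 candidates, the second has 3
    return compute_reranker_metrics(scores, labels, groups=[4, 3], k_values=[1, 3])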
def evaluate_cross_lingual_retrieval(
query_embeddings: Dict[str, np.ndarray],
passage_embeddings: np.ndarray,
labels: Dict[str, np.ndarray],
languages: List[str] = ['en', 'zh']
) -> Dict[str, Dict[str, float]]:
"""
Evaluate cross-lingual retrieval performance
Args:
query_embeddings: Dict mapping language to query embeddings
passage_embeddings: Passage embeddings
labels: Dict mapping language to relevance labels
languages: List of languages to evaluate
Returns:
Nested dict of metrics per language
"""
results = {}
for lang in languages:
if lang in query_embeddings and lang in labels:
logger.info(f"Evaluating {lang} queries")
results[lang] = compute_retrieval_metrics(
query_embeddings[lang],
passage_embeddings,
labels[lang]
)
# Compute cross-lingual metrics if we have multiple languages
if len(languages) > 1 and all(lang in results for lang in languages):
# Average performance across languages
avg_metrics = {}
for metric in results[languages[0]].keys():
if metric != 'scores':
values = [results[lang][metric] for lang in languages]
avg_metrics[metric] = np.mean(values)
results['average'] = avg_metrics
return results
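# Usage sketch (illustrative only; the language codes, shapes, and labels are
# assumptions): queries from each language are scored against a shared passage pool,
# and an 'average' entry aggregates the per-language metrics.
def _example_cross_lingual_eval() -> Dict[str, Dict[str, float]]:
    rng = np.random.default_rng(2)
    passages = rng.standard_normal((6, 16))
    query_embeddings = {
        'en': rng.standard_normal((2, 16)),
        'zh': rng.standard_normal((2, 16)),
    }
    labels = {
        'en': np.array([0, 3]),  # relevant passage index for each English query
        'zh': np.array([1, 5]),  # relevant passage index for each Chinese query
    }
    return evaluate_cross_lingual_retrieval(query_embeddings, passages, labels)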
class MetricsTracker:
"""
Track and aggregate metrics during training
"""
def __init__(self):
self.metrics = defaultdict(list)
self.best_metrics = {}
self.best_steps = {}
def update(self, metrics: Dict[str, float], step: int):
"""Update metrics for current step"""
for key, value in metrics.items():
self.metrics[key].append((step, value))
            # Track best metrics (assumes higher is better for every tracked metric)
            if key not in self.best_metrics or value > self.best_metrics[key]:
self.best_metrics[key] = value
self.best_steps[key] = step
def get_current(self, metric: str) -> Optional[float]:
"""Get most recent value for a metric"""
if metric in self.metrics and self.metrics[metric]:
return self.metrics[metric][-1][1]
return None
def get_best(self, metric: str) -> Optional[Tuple[int, float]]:
"""Get best value and step for a metric"""
if metric in self.best_metrics:
return self.best_steps[metric], self.best_metrics[metric]
return None
def get_history(self, metric: str) -> List[Tuple[int, float]]:
"""Get full history for a metric"""
return self.metrics.get(metric, [])
def summary(self) -> Dict[str, Dict[str, float]]:
"""Get summary of all metrics"""
summary = {
'current': {},
'best': {},
'best_steps': {}
}
for metric in self.metrics:
summary['current'][metric] = self.get_current(metric)
summary['best'][metric] = self.best_metrics.get(metric)
summary['best_steps'][metric] = self.best_steps.get(metric)
return summary
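# Usage sketch (illustrative only; the metric names and steps are assumptions):
# feed evaluation results into the tracker during training, then query the latest
# value, the best value with its step, and a full summary.
def _example_metrics_tracker() -> Dict[str, Dict[str, float]]:
    tracker = MetricsTracker()
    tracker.update({'recall@10': 0.41, 'mrr': 0.33}, step=100)
    tracker.update({'recall@10': 0.47, 'mrr': 0.31}, step=200)
    assert tracker.get_current('mrr') == 0.31
    assert tracker.get_best('recall@10') == (200, 0.47)
    return tracker.summary()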