import logging
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from sklearn.metrics import average_precision_score, ndcg_score

logger = logging.getLogger(__name__)


def compute_retrieval_metrics(
    query_embeddings: np.ndarray,
    passage_embeddings: np.ndarray,
    labels: np.ndarray,
    k_values: List[int] = [1, 5, 10, 20, 100],
    return_all_scores: bool = False
) -> Dict[str, float]:
    """
    Compute standard retrieval metrics from dense embeddings.

    Args:
        query_embeddings: Query embeddings [num_queries, embedding_dim]
        passage_embeddings: Passage embeddings [num_passages, embedding_dim]
        labels: Binary relevance labels [num_queries, num_passages] or 1-D array of relevant passage indices
        k_values: List of k values for metrics@k
        return_all_scores: Whether to return all similarity scores

    Returns:
        Dictionary of metrics
    """
    metrics = {}

    # Compute similarity scores
    scores = np.matmul(query_embeddings, passage_embeddings.T)

    # Get number of queries
    num_queries = query_embeddings.shape[0]

    # Convert labels to binary matrix if needed
    if labels.ndim == 1:
        # Assume labels contains relevant passage indices
        binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
        for i, idx in enumerate(labels):
            if idx < passage_embeddings.shape[0]:
                binary_labels[i, idx] = 1
        labels = binary_labels

    # Compute metrics for each k
    for k in k_values:
        metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
        metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
        metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
        metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
        metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)

    # Overall metrics
    metrics['map'] = compute_mean_average_precision(scores, labels)
    metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)

    if return_all_scores:
        metrics['scores'] = scores

    return metrics


def compute_retrieval_metrics_comprehensive(
    query_embeddings: np.ndarray,
    passage_embeddings: np.ndarray,
    labels: np.ndarray,
    k_values: List[int] = [1, 5, 10, 20, 100],
    return_all_scores: bool = False,
    compute_additional_metrics: bool = True
) -> Dict[str, float]:
    """
    Compute comprehensive retrieval metrics with additional statistics.

    Args:
        query_embeddings: Query embeddings [num_queries, embedding_dim]
        passage_embeddings: Passage embeddings [num_passages, embedding_dim]
        labels: Binary relevance labels [num_queries, num_passages] or 1-D array of relevant passage indices
        k_values: List of k values for metrics@k
        return_all_scores: Whether to return all similarity scores
        compute_additional_metrics: Whether to compute additional metrics like success rate

    Returns:
        Dictionary of comprehensive metrics
    """
    metrics = {}

    # Validate inputs
    if query_embeddings.shape[0] == 0 or passage_embeddings.shape[0] == 0:
        raise ValueError("Empty embeddings provided")

    # Normalize embeddings for cosine similarity
    query_embeddings = query_embeddings / (np.linalg.norm(query_embeddings, axis=1, keepdims=True) + 1e-8)
    passage_embeddings = passage_embeddings / (np.linalg.norm(passage_embeddings, axis=1, keepdims=True) + 1e-8)

    # Compute similarity scores
    scores = np.matmul(query_embeddings, passage_embeddings.T)

    # Get number of queries
    num_queries = query_embeddings.shape[0]

    # Convert labels to binary matrix if needed
    if labels.ndim == 1:
        # Assume labels contains relevant passage indices
        binary_labels = np.zeros((num_queries, passage_embeddings.shape[0]))
        for i, idx in enumerate(labels):
            if idx < passage_embeddings.shape[0]:
                binary_labels[i, idx] = 1
        labels = binary_labels
    # Compute metrics for each k
    for k in k_values:
        metrics[f'recall@{k}'] = compute_recall_at_k(scores, labels, k)
        metrics[f'precision@{k}'] = compute_precision_at_k(scores, labels, k)
        metrics[f'map@{k}'] = compute_map_at_k(scores, labels, k)
        metrics[f'mrr@{k}'] = compute_mrr_at_k(scores, labels, k)
        metrics[f'ndcg@{k}'] = compute_ndcg_at_k(scores, labels, k)

        if compute_additional_metrics:
            # Success rate (queries with at least one relevant doc in top-k)
            success_count = 0
            for i in range(num_queries):
                top_k_indices = np.argsort(scores[i])[::-1][:k]
                if np.any(labels[i][top_k_indices] > 0):
                    success_count += 1
            metrics[f'success@{k}'] = success_count / num_queries if num_queries > 0 else 0.0

    # Overall metrics
    metrics['map'] = compute_mean_average_precision(scores, labels)
    metrics['mrr'] = compute_mean_reciprocal_rank(scores, labels)

    if compute_additional_metrics:
        # Average number of relevant documents per query
        metrics['avg_relevant_per_query'] = float(np.mean(np.sum(labels > 0, axis=1)))

        # Coverage: fraction of relevant documents that appear in any top-k list
        max_k = max(k_values)
        covered_relevant = set()
        for i in range(num_queries):
            top_k_indices = np.argsort(scores[i])[::-1][:max_k]
            relevant_indices = np.where(labels[i] > 0)[0]
            for idx in relevant_indices:
                if idx in top_k_indices:
                    covered_relevant.add((i, idx))
        total_relevant = np.sum(labels > 0)
        metrics['coverage'] = len(covered_relevant) / total_relevant if total_relevant > 0 else 0.0

    if return_all_scores:
        metrics['scores'] = scores

    return metrics


def compute_recall_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Recall@K.

    Args:
        scores: Similarity scores [num_queries, num_passages]
        labels: Binary relevance labels
        k: Top-k to consider

    Returns:
        Average recall@k
    """
    num_queries = scores.shape[0]
    recall_scores = []

    for i in range(num_queries):
        # Get top-k indices
        top_k_indices = np.argsort(scores[i])[::-1][:k]

        # Count relevant documents
        num_relevant = labels[i].sum()
        if num_relevant == 0:
            continue

        # Count relevant in top-k
        num_relevant_in_topk = labels[i][top_k_indices].sum()
        recall_scores.append(num_relevant_in_topk / num_relevant)

    return float(np.mean(recall_scores)) if recall_scores else 0.0


def compute_precision_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Precision@K.
    """
    num_queries = scores.shape[0]
    precision_scores = []

    for i in range(num_queries):
        # Get top-k indices
        top_k_indices = np.argsort(scores[i])[::-1][:k]

        # Count relevant in top-k
        num_relevant_in_topk = labels[i][top_k_indices].sum()
        precision_scores.append(num_relevant_in_topk / k)

    return float(np.mean(precision_scores)) if precision_scores else 0.0


def compute_map_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Mean Average Precision@K.
    """
    num_queries = scores.shape[0]
    ap_scores = []

    for i in range(num_queries):
        # Get top-k indices
        top_k_indices = np.argsort(scores[i])[::-1][:k]

        # Compute average precision over the top-k ranking
        num_relevant = 0
        sum_precision = 0.0
        for j, idx in enumerate(top_k_indices):
            if labels[i][idx] == 1:
                num_relevant += 1
                sum_precision += num_relevant / (j + 1)

        if labels[i].sum() > 0:
            ap_scores.append(sum_precision / min(labels[i].sum(), k))

    return float(np.mean(ap_scores)) if ap_scores else 0.0


def compute_mrr_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Mean Reciprocal Rank@K with proper error handling.

    Args:
        scores: Similarity scores [num_queries, num_passages]
        labels: Binary relevance labels
        k: Top-k to consider

    Returns:
        MRR@k score
    """
    if scores.shape != labels.shape:
        raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")
    num_queries = scores.shape[0]
    rr_scores = []

    for i in range(num_queries):
        # Skip queries with no relevant documents
        if labels[i].sum() == 0:
            continue

        # Get top-k indices
        top_k_indices = np.argsort(scores[i])[::-1][:k]

        # Find first relevant document
        for j, idx in enumerate(top_k_indices):
            if labels[i][idx] > 0:  # Support graded relevance
                rr_scores.append(1.0 / (j + 1))
                break
        else:
            # No relevant document in top-k
            rr_scores.append(0.0)

    return float(np.mean(rr_scores)) if rr_scores else 0.0


def compute_ndcg_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """
    Compute Normalized Discounted Cumulative Gain@K with robust error handling.

    Args:
        scores: Similarity scores [num_queries, num_passages]
        labels: Binary or graded relevance labels
        k: Top-k to consider

    Returns:
        NDCG@k score
    """
    if scores.shape != labels.shape:
        raise ValueError(f"Scores shape {scores.shape} != labels shape {labels.shape}")

    num_queries = scores.shape[0]
    ndcg_scores = []

    for i in range(num_queries):
        # Skip queries with no relevant documents
        if labels[i].sum() == 0:
            continue

        # Get scores and labels for this query
        query_scores = scores[i]
        query_labels = labels[i]

        # Get top-k indices by score
        top_k_indices = np.argsort(query_scores)[::-1][:k]

        # Compute DCG@k
        dcg = 0.0
        for j, idx in enumerate(top_k_indices):
            if query_labels[idx] > 0:
                # Use log2(j + 2) as the denominator (rank is j + 1, shifted by 1 to avoid log(1) = 0)
                dcg += query_labels[idx] / np.log2(j + 2)

        # Compute ideal DCG@k
        ideal_labels = np.sort(query_labels)[::-1][:k]
        idcg = 0.0
        for j, label in enumerate(ideal_labels):
            if label > 0:
                idcg += label / np.log2(j + 2)

        # Compute NDCG
        if idcg > 0:
            ndcg_scores.append(dcg / idcg)
        else:
            ndcg_scores.append(0.0)

    return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0


def compute_mean_average_precision(scores: np.ndarray, labels: np.ndarray) -> float:
    """
    Compute Mean Average Precision (MAP) over all positions.
    """
    num_queries = scores.shape[0]
    ap_scores = []

    for i in range(num_queries):
        if labels[i].sum() > 0:
            ap = average_precision_score(labels[i], scores[i])
            ap_scores.append(ap)

    return float(np.mean(ap_scores)) if ap_scores else 0.0


def compute_mean_reciprocal_rank(scores: np.ndarray, labels: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR) over all positions.
    """
    num_queries = scores.shape[0]
    rr_scores = []

    for i in range(num_queries):
        # Get ranking
        ranked_indices = np.argsort(scores[i])[::-1]

        # Find first relevant document
        for j, idx in enumerate(ranked_indices):
            if labels[i][idx] == 1:
                rr_scores.append(1.0 / (j + 1))
                break
        else:
            rr_scores.append(0.0)

    return float(np.mean(rr_scores)) if rr_scores else 0.0


def compute_reranker_metrics(
    scores: Union[np.ndarray, torch.Tensor],
    labels: Union[np.ndarray, torch.Tensor],
    groups: Optional[List[int]] = None,
    k_values: List[int] = [1, 5, 10]
) -> Dict[str, float]:
    """
    Compute metrics specifically for reranker evaluation.

    Args:
        scores: Relevance scores from reranker
        labels: Binary relevance labels
        groups: Group sizes for each query (for listwise evaluation)
        k_values: List of k values

    Returns:
        Dictionary of reranker-specific metrics
    """
    # Convert to numpy if needed
    if isinstance(scores, torch.Tensor):
        scores = scores.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    metrics = {}

    if groups is None:
        # Assume pairwise evaluation: binary classification metrics
        predictions = (scores > 0.5).astype(int)

        # Accuracy
        metrics['accuracy'] = float(np.mean(predictions == labels))

        # Compute AUC if we have both positive and negative examples
        if len(np.unique(labels)) > 1:
            from sklearn.metrics import roc_auc_score
            metrics['auc'] = float(roc_auc_score(labels, scores))
    else:
        # Listwise evaluation
        start_idx = 0
        mrr_scores = []
        ndcg_scores = defaultdict(list)

        for group_size in groups:
            # Get scores and labels for this group
            group_scores = scores[start_idx:start_idx + group_size]
            group_labels = labels[start_idx:start_idx + group_size]

            # Sort by scores
            sorted_indices = np.argsort(group_scores)[::-1]

            # MRR
            for i, idx in enumerate(sorted_indices):
                if group_labels[idx] == 1:
                    mrr_scores.append(1.0 / (i + 1))
                    break
            else:
                mrr_scores.append(0.0)

            # NDCG@k
            for k in k_values:
                if group_size >= k:
                    try:
                        ndcg = ndcg_score(
                            group_labels.reshape(1, -1),
                            group_scores.reshape(1, -1),
                            k=k
                        )
                        ndcg_scores[k].append(ndcg)
                    except ValueError:
                        # Skip groups that sklearn cannot score (e.g. degenerate labels)
                        pass

            start_idx += group_size

        # Aggregate metrics
        metrics['mrr'] = float(np.mean(mrr_scores)) if mrr_scores else 0.0
        for k in k_values:
            if ndcg_scores[k]:
                metrics[f'ndcg@{k}'] = float(np.mean(ndcg_scores[k]))

    return metrics


def evaluate_cross_lingual_retrieval(
    query_embeddings: Dict[str, np.ndarray],
    passage_embeddings: np.ndarray,
    labels: Dict[str, np.ndarray],
    languages: List[str] = ['en', 'zh']
) -> Dict[str, Dict[str, float]]:
    """
    Evaluate cross-lingual retrieval performance.

    Args:
        query_embeddings: Dict mapping language to query embeddings
        passage_embeddings: Passage embeddings
        labels: Dict mapping language to relevance labels
        languages: List of languages to evaluate

    Returns:
        Nested dict of metrics per language
    """
    results = {}

    for lang in languages:
        if lang in query_embeddings and lang in labels:
            logger.info(f"Evaluating {lang} queries")
            results[lang] = compute_retrieval_metrics(
                query_embeddings[lang],
                passage_embeddings,
                labels[lang]
            )

    # Compute cross-lingual metrics if we have multiple languages
    if len(languages) > 1 and all(lang in results for lang in languages):
        # Average performance across languages
        avg_metrics = {}
        for metric in results[languages[0]].keys():
            if metric != 'scores':
                values = [results[lang][metric] for lang in languages]
                avg_metrics[metric] = float(np.mean(values))
        results['average'] = avg_metrics

    return results


class MetricsTracker:
    """
    Track and aggregate metrics during training.
    """

    def __init__(self):
        self.metrics = defaultdict(list)
        self.best_metrics = {}
        self.best_steps = {}

    def update(self, metrics: Dict[str, float], step: int):
        """Update metrics for the current step."""
        for key, value in metrics.items():
            self.metrics[key].append((step, value))

            # Track best metrics (assumes higher is better)
            if key not in self.best_metrics or value > self.best_metrics[key]:
                self.best_metrics[key] = value
                self.best_steps[key] = step

    def get_current(self, metric: str) -> Optional[float]:
        """Get the most recent value for a metric."""
        if metric in self.metrics and self.metrics[metric]:
            return self.metrics[metric][-1][1]
        return None

    def get_best(self, metric: str) -> Optional[Tuple[int, float]]:
        """Get the best value and its step for a metric."""
        if metric in self.best_metrics:
            return self.best_steps[metric], self.best_metrics[metric]
        return None

    def get_history(self, metric: str) -> List[Tuple[int, float]]:
        """Get the full history for a metric."""
        return self.metrics.get(metric, [])

    def summary(self) -> Dict[str, Dict[str, float]]:
        """Get a summary of all tracked metrics."""
        summary = {
            'current': {},
            'best': {},
            'best_steps': {}
        }
        for metric in self.metrics:
            summary['current'][metric] = self.get_current(metric)
            summary['best'][metric] = self.best_metrics.get(metric)
            summary['best_steps'][metric] = self.best_steps.get(metric)
        return summary
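

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The toy embeddings, relevance
# matrix, and step numbers below are made up for demonstration and are not
# part of any real evaluation; real callers would pass encoder outputs and
# qrels instead. It shows how compute_retrieval_metrics and MetricsTracker
# fit together, assuming binary relevance and unit-normalized embeddings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Hypothetical corpus: 3 queries, 8 passages, 4-dimensional embeddings
    queries = rng.normal(size=(3, 4))
    passages = rng.normal(size=(8, 4))
    queries /= np.linalg.norm(queries, axis=1, keepdims=True)
    passages /= np.linalg.norm(passages, axis=1, keepdims=True)

    # Binary relevance matrix [num_queries, num_passages]
    relevance = np.zeros((3, 8))
    relevance[0, 2] = 1
    relevance[1, [0, 5]] = 1
    relevance[2, 7] = 1

    retrieval_metrics = compute_retrieval_metrics(
        queries, passages, relevance, k_values=[1, 5]
    )
    print({name: round(value, 4) for name, value in retrieval_metrics.items()})

    # Track metrics across (hypothetical) evaluation steps
    tracker = MetricsTracker()
    tracker.update(retrieval_metrics, step=100)
    tracker.update({name: value * 1.05 for name, value in retrieval_metrics.items()}, step=200)
    print(tracker.get_best('ndcg@5'))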