import torch
import numpy as np
from typing import List, Dict, Optional, Tuple, Union
from tqdm import tqdm
import logging
from torch.utils.data import DataLoader
import json
import os
from collections import defaultdict

from models.bge_m3 import BGEM3Model
from models.bge_reranker import BGERerankerModel
from .metrics import (
    compute_retrieval_metrics,
    compute_reranker_metrics,
    evaluate_cross_lingual_retrieval,
    MetricsTracker
)

logger = logging.getLogger(__name__)


class RetrievalEvaluator:
    """
    Comprehensive evaluator for retrieval models

    Notes:
        - All parameters (model, tokenizer, device, metrics, k_values, etc.)
          should be passed from config-driven scripts.
        - This class does not load config directly; pass config values from
          your main script for consistency.
    """

    def __init__(
        self,
        model: Union[BGEM3Model, BGERerankerModel],
        tokenizer,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        metrics: List[str] = ['recall', 'precision', 'map', 'mrr', 'ndcg'],
        k_values: List[int] = [1, 5, 10, 20, 100]
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.metrics = metrics
        self.k_values = k_values

        # Move model to device
        self.model.to(device)
        self.model.eval()

        # Metrics tracker
        self.tracker = MetricsTracker()

    def evaluate(
        self,
        dataloader: DataLoader,
        desc: str = "Evaluating",
        return_embeddings: bool = False
    ) -> Dict[str, Union[float, list, np.ndarray]]:
        """
        Evaluate model on dataset

        Args:
            dataloader: DataLoader with evaluation data
            desc: Description for progress bar
            return_embeddings: Whether to return computed embeddings

        Returns:
            Dictionary of evaluation metrics
        """
        self.model.eval()

        all_query_embeddings = []
        all_passage_embeddings = []
        all_labels = []
        all_queries = []
        all_passages = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc=desc):
                # Move batch to device
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # Get embeddings based on model type
                if isinstance(self.model, BGEM3Model):
                    # Bi-encoder evaluation
                    query_embeddings, passage_embeddings = self._evaluate_biencoder_batch(batch)
                    all_query_embeddings.append(query_embeddings.cpu())
                    all_passage_embeddings.append(passage_embeddings.cpu())
                else:
                    # Cross-encoder evaluation
                    scores = self._evaluate_crossencoder_batch(batch)
                    # For cross-encoder, we store scores instead of embeddings
                    all_query_embeddings.append(scores.cpu())

                # Collect labels
                if 'labels' in batch:
                    all_labels.append(batch['labels'].cpu())

                # Collect text if available
                if 'query' in batch:
                    all_queries.extend(batch['query'])
                if 'passages' in batch:
                    all_passages.extend(batch['passages'])

        # Concatenate results
        if isinstance(self.model, BGEM3Model):
            all_query_embeddings = torch.cat(all_query_embeddings, dim=0).numpy()
            all_passage_embeddings = torch.cat(all_passage_embeddings, dim=0).numpy()
            all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None

            # Compute retrieval metrics
            if all_labels is not None:
                metrics = compute_retrieval_metrics(
                    all_query_embeddings,
                    all_passage_embeddings,
                    all_labels,
                    k_values=self.k_values
                )
            else:
                metrics = {}
        else:
            # Cross-encoder metrics
            all_scores = torch.cat(all_query_embeddings, dim=0).numpy()
            all_labels = torch.cat(all_labels, dim=0).numpy() if all_labels else None

            if all_labels is not None:
                # Determine if we have groups (listwise) or pairs
                if hasattr(dataloader.dataset, 'train_group_size'):
                    groups = [getattr(dataloader.dataset, 'train_group_size')] * len(dataloader.dataset)  # type: ignore[arg-type]
                else:
                    groups = None

                metrics = compute_reranker_metrics(
                    all_scores,
                    all_labels,
                    groups=groups,
                    k_values=self.k_values
                )
            else:
                metrics = {}

        # Update tracker
        self.tracker.update(metrics, step=0)

        # Return embeddings if requested
        if return_embeddings and isinstance(self.model, BGEM3Model):
            metrics['query_embeddings'] = all_query_embeddings  # type: ignore[assignment]
            metrics['passage_embeddings'] = all_passage_embeddings  # type: ignore[assignment]
            metrics['queries'] = all_queries  # type: ignore[assignment]
            metrics['passages'] = all_passages  # type: ignore[assignment]

        return metrics  # type: ignore[return-value]

    def _evaluate_biencoder_batch(self, batch: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate a batch for bi-encoder model"""
        # Get query embeddings
        query_outputs = self.model(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
            return_dense=True
        )
        query_embeddings = query_outputs['dense']

        # Get passage embeddings
        # Handle multiple passages per query
        if batch['passage_input_ids'].dim() == 3:
            # Flatten batch dimension
            batch_size, num_passages, seq_len = batch['passage_input_ids'].shape
            passage_input_ids = batch['passage_input_ids'].view(-1, seq_len)
            passage_attention_mask = batch['passage_attention_mask'].view(-1, seq_len)
        else:
            passage_input_ids = batch['passage_input_ids']
            passage_attention_mask = batch['passage_attention_mask']

        passage_outputs = self.model(
            input_ids=passage_input_ids,
            attention_mask=passage_attention_mask,
            return_dense=True
        )
        passage_embeddings = passage_outputs['dense']

        return query_embeddings, passage_embeddings

    def _evaluate_crossencoder_batch(self, batch: Dict) -> torch.Tensor:
        """Evaluate a batch for cross-encoder model"""
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        return outputs['logits']

    def evaluate_retrieval_pipeline(
        self,
        queries: List[str],
        corpus: List[str],
        relevant_docs: Dict[str, List[int]],
        batch_size: int = 32
    ) -> Dict[str, float]:
        """
        Evaluate full retrieval pipeline

        Args:
            queries: List of queries
            corpus: List of all documents
            relevant_docs: Dict mapping query to relevant doc indices
            batch_size: Batch size for encoding

        Returns:
            Evaluation metrics
        """
        logger.info(
            f"Evaluating retrieval pipeline with {len(queries)} queries over {len(corpus)} documents"
        )

        # Encode all documents
        logger.info("Encoding corpus...")
        corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")

        # Encode queries and evaluate
        all_metrics = defaultdict(list)

        for i in tqdm(range(0, len(queries), batch_size), desc="Evaluating queries"):
            batch_queries = queries[i:i + batch_size]
            batch_relevant = [relevant_docs.get(q, []) for q in batch_queries]

            # Encode queries
            query_embeddings = self._encode_texts(batch_queries, batch_size, show_progress=False)

            # Create labels matrix
            labels = np.zeros((len(batch_queries), len(corpus)))
            for j, relevant_indices in enumerate(batch_relevant):
                for idx in relevant_indices:
                    if idx < len(corpus):
                        labels[j, idx] = 1

            # Compute metrics
            batch_metrics = compute_retrieval_metrics(
                query_embeddings,
                corpus_embeddings,
                labels,
                k_values=self.k_values
            )

            # Accumulate metrics
            for metric, value in batch_metrics.items():
                if metric != 'scores':
                    all_metrics[metric].append(value)

        # Average metrics
        final_metrics = {
            metric: np.mean(values)
            for metric, values in all_metrics.items()
        }

        return final_metrics  # type: ignore[return-value]

    def _encode_texts(
        self,
        texts: List[str],
        batch_size: int,
        desc: str = "Encoding",
        show_progress: bool = True
    ) -> np.ndarray:
        """Encode texts to embeddings"""
        if isinstance(self.model, BGEM3Model):
            # Use model's encode method
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                device=self.device
            )
            if isinstance(embeddings, torch.Tensor):
                embeddings = embeddings.cpu().numpy()
        else:
            # For cross-encoder, this doesn't make sense
            raise ValueError("Cannot encode texts with cross-encoder model")

        return embeddings  # type: ignore[return-value]

    def evaluate_cross_lingual(
        self,
        queries: Dict[str, List[str]],
        corpus: List[str],
        relevant_docs: Dict[str, Dict[str, List[int]]],
        batch_size: int = 32
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate cross-lingual retrieval

        Args:
            queries: Dict mapping language to queries
            corpus: List of documents (assumed to be in source language)
            relevant_docs: Nested dict of language -> query -> relevant indices
            batch_size: Batch size

        Returns:
            Metrics per language
        """
        # Encode corpus once
        corpus_embeddings = self._encode_texts(corpus, batch_size, desc="Encoding corpus")

        # Evaluate each language
        results = {}
        for lang, lang_queries in queries.items():
            logger.info(f"Evaluating {lang} queries...")

            # Encode queries
            query_embeddings = self._encode_texts(
                lang_queries, batch_size, desc=f"Encoding {lang} queries"
            )

            # Create labels
            labels = np.zeros((len(lang_queries), len(corpus)))
            for i, query in enumerate(lang_queries):
                if query in relevant_docs.get(lang, {}):
                    for idx in relevant_docs[lang][query]:
                        if idx < len(corpus):
                            labels[i, idx] = 1

            # Compute metrics
            metrics = compute_retrieval_metrics(
                query_embeddings,
                corpus_embeddings,
                labels,
                k_values=self.k_values
            )

            results[lang] = metrics

        # Compute average across languages
        if len(results) > 1:
            avg_metrics = {}
            for metric in results[list(results.keys())[0]]:
                if metric != 'scores':
                    values = [results[lang][metric] for lang in results]
                    avg_metrics[metric] = np.mean(values)
            results['average'] = avg_metrics

        return results

    def save_evaluation_results(
        self,
        metrics: Dict[str, float],
        output_dir: str,
        prefix: str = "eval"
    ):
        """Save evaluation results"""
        os.makedirs(output_dir, exist_ok=True)

        # Save metrics
        metrics_file = os.path.join(output_dir, f"{prefix}_metrics.json")
        with open(metrics_file, 'w') as f:
            json.dump(metrics, f, indent=2)

        # Save summary
        summary_file = os.path.join(output_dir, f"{prefix}_summary.txt")
        with open(summary_file, 'w') as f:
            f.write(f"Evaluation Results - {prefix}\n")
            f.write("=" * 50 + "\n\n")

            for metric, value in sorted(metrics.items()):
                if isinstance(value, float):
                    f.write(f"{metric:.<30} {value:.4f}\n")

        logger.info(f"Saved evaluation results to {output_dir}")
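
# ---------------------------------------------------------------------------
# Example usage of RetrievalEvaluator (illustrative sketch, not executed here).
# `build_model_and_tokenizer` is a hypothetical config-driven loader assumed to
# exist elsewhere in the project (see the class docstring notes); `eval_queries`,
# `eval_corpus` and `eval_qrels` stand in for data prepared by the caller.
#
#     model, tokenizer = build_model_and_tokenizer(config)   # hypothetical helper
#     evaluator = RetrievalEvaluator(model, tokenizer, k_values=[1, 5, 10])
#     metrics = evaluator.evaluate_retrieval_pipeline(
#         queries=eval_queries,        # List[str]
#         corpus=eval_corpus,          # List[str]
#         relevant_docs=eval_qrels,    # Dict[str, List[int]]: query -> relevant doc indices
#         batch_size=32,
#     )
#     evaluator.save_evaluation_results(metrics, output_dir="eval_output")
# ---------------------------------------------------------------------------
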

class InteractiveEvaluator:
    """
    Interactive evaluation for qualitative analysis
    """

    def __init__(
        self,
        retriever: BGEM3Model,
        reranker: Optional[BGERerankerModel],
        tokenizer,
        corpus: List[str],
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.tokenizer = tokenizer
        self.corpus = corpus
        self.device = device

        # Pre-encode corpus
        logger.info("Pre-encoding corpus...")
        self.corpus_embeddings = self._encode_corpus()
        logger.info(f"Encoded {len(self.corpus)} documents")

    def _encode_corpus(self) -> np.ndarray:
        """Pre-encode all corpus documents"""
        embeddings = self.retriever.encode(
            self.corpus,
            batch_size=32,
            device=self.device
        )
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()
        return embeddings  # type: ignore[return-value]

    def search(
        self,
        query: str,
        top_k: int = 10,
        rerank: bool = True
    ) -> List[Tuple[int, float, str]]:
        """
        Search for query in corpus

        Args:
            query: Search query
            top_k: Number of results to return
            rerank: Whether to use reranker

        Returns:
            List of (index, score, text) tuples
        """
        # Encode query
        query_embedding = self.retriever.encode(
            [query],
            device=self.device
        )
        if isinstance(query_embedding, torch.Tensor):
            query_embedding = query_embedding.cpu().numpy()

        # Compute similarities
        similarities = np.dot(query_embedding, self.corpus_embeddings.T).squeeze()  # type: ignore[arg-type]

        # Get top-k
        top_indices = np.argsort(similarities)[::-1][:top_k * 2 if rerank else top_k]

        results = []
        for idx in top_indices:
            results.append((int(idx), float(similarities[idx]), self.corpus[idx]))

        # Rerank if requested
        if rerank and self.reranker:
            reranked_indices = self.reranker.rerank(
                query,
                [self.corpus[idx] for idx in top_indices],
                return_scores=False
            )

            # Reorder results
            reranked_results = []
            for i, rerank_idx in enumerate(reranked_indices[:top_k]):
                original_idx = top_indices[rerank_idx]  # type: ignore[index]
                reranked_results.append((
                    int(original_idx),
                    float(i + 1),  # Rank as score
                    self.corpus[original_idx]
                ))
            results = reranked_results

        return results[:top_k]

    def evaluate_query_interactively(
        self,
        query: str,
        relevant_indices: List[int],
        top_k: int = 10
    ) -> Dict[str, float]:
        """
        Evaluate a single query interactively

        Args:
            query: Query to evaluate
            relevant_indices: Ground truth relevant document indices
            top_k: Number of results to consider

        Returns:
            Metrics for this query
        """
        # Get search results
        results = self.search(query, top_k=top_k, rerank=True)
        retrieved_indices = [r[0] for r in results]

        # Calculate metrics
        metrics = {}

        # Precision@k
        relevant_retrieved = sum(1 for idx in retrieved_indices if idx in relevant_indices)
        metrics[f'precision@{top_k}'] = relevant_retrieved / top_k

        # Recall@k
        if relevant_indices:
            metrics[f'recall@{top_k}'] = relevant_retrieved / len(relevant_indices)
        else:
            metrics[f'recall@{top_k}'] = 0.0

        # MRR
        for i, idx in enumerate(retrieved_indices):
            if idx in relevant_indices:
                metrics['mrr'] = 1.0 / (i + 1)
                break
        else:
            metrics['mrr'] = 0.0

        # Print results for inspection
        print(f"\nQuery: {query}")
        print(f"Relevant docs: {relevant_indices}")
        print(f"\nTop {top_k} results:")
        for i, (idx, score, text) in enumerate(results):
            relevance = "✓" if idx in relevant_indices else "✗"
            print(f"{i + 1}. [{relevance}] (idx={idx}, score={score:.3f}) {text[:100]}...")

        print(f"\nMetrics: {metrics}")

        return metrics
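
# ---------------------------------------------------------------------------
# Example usage of InteractiveEvaluator (illustrative sketch, not executed here).
# `retriever`, `reranker` and `tokenizer` are assumed to be built by the
# project's config-driven scripts (hypothetical setup); `docs` stands in for
# any List[str] corpus and the query / relevant indices are placeholders.
#
#     explorer = InteractiveEvaluator(retriever, reranker, tokenizer, corpus=docs)
#     hits = explorer.search("example query", top_k=5, rerank=True)
#     for idx, score, text in hits:
#         print(idx, score, text[:80])
#
#     # Spot-check a single query against known relevant document indices:
#     explorer.evaluate_query_interactively("example query",
#                                           relevant_indices=[12, 47], top_k=10)
# ---------------------------------------------------------------------------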