import argparse
import json
import logging
import pickle
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch

from text_embedder import CLIPTextEmbedder

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
DEFAULT_TEXT_MODEL = "openai/clip-vit-base-patch32"
DEFAULT_TOP_K = 10
# Single source of truth for the default shared by the class and the CLI flag.
DEFAULT_EMBEDDING_DIM = 4096


class TextToImageSearcher:
    """Rank precomputed image embeddings against a text query.

    Image embeddings are read once from a pickle file that maps an image
    path to its embedding array. Queries are embedded with
    ``CLIPTextEmbedder`` and scored against every image by dot product.
    """

    def __init__(
        self,
        image_embeddings_path: Path,
        text_model_name: str = DEFAULT_TEXT_MODEL,
        embedding_dim: int = DEFAULT_EMBEDDING_DIM,
    ):
        """Load the image embeddings and set up the text embedder.

        Args:
            image_embeddings_path: Path to the pickle file containing the
                image embeddings dictionary.
            text_model_name: Model name passed to ``CLIPTextEmbedder``.
            embedding_dim: Expected size of the last axis of the image
                embeddings; also passed to the embedder as ``output_dim``.

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: If the file cannot be loaded or fails validation.
        """
        self.image_embeddings_path = Path(image_embeddings_path)
        self.embedding_dim = embedding_dim

        # Load image embeddings
        logger.info("Loading image embeddings from %s", self.image_embeddings_path)
        self.image_data = self._load_image_embeddings()

        # Initialize text embedder
        logger.info("Initializing text embedder: %s", text_model_name)
        self.text_embedder = CLIPTextEmbedder(
            model_name=text_model_name,
            output_dim=embedding_dim,
        )

        logger.info(
            "TextToImageSearcher ready with %s images", len(self.image_data)
        )

    def _validate_embeddings(self, data) -> None:
        """Raise ``ValueError`` if *data* is not a usable embeddings mapping.

        Only the first entry's last-axis size is compared with
        ``self.embedding_dim``; the remaining entries are not inspected.
        """
        if not isinstance(data, dict):
            raise ValueError("Image embeddings file should contain a dictionary")
        if not data:
            # Previously surfaced as a StopIteration with an empty message.
            raise ValueError("Image embeddings file contains no embeddings")

        sample_emb = next(iter(data.values()))
        if sample_emb.shape[-1] != self.embedding_dim:
            raise ValueError(
                f"Expected {self.embedding_dim}D embeddings, "
                f"got {sample_emb.shape[-1]}D"
            )

    def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
        """Load and validate the image embeddings pickle.

        Returns:
            The unpickled dictionary, unchanged.

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: For any other load or validation failure; the
                original exception is attached as ``__cause__``.
        """
        try:
            with open(self.image_embeddings_path, "rb") as f:
                # NOTE(security): pickle.load can execute arbitrary code
                # embedded in the file. Only load embeddings files that
                # come from a trusted source.
                data = pickle.load(f)
            self._validate_embeddings(data)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Image embeddings file not found: {self.image_embeddings_path}"
            ) from e
        except Exception as e:
            # Validation errors are wrapped too, so callers keep seeing
            # RuntimeError for every non-"missing file" failure.
            raise RuntimeError(f"Failed to load image embeddings: {e}") from e

        logger.info(
            "Loaded %s image embeddings of dimension %s",
            len(data),
            self.embedding_dim,
        )
        return data

    def _compute_similarities(self, query_text: str):
        """Return the dot product of every image embedding with the query.

        The result is ordered like ``self.image_data``. The values are
        cosine similarities only if both the stored image embeddings and
        the embedder output are L2-normalized -- assumed by this module,
        not verified here.
        """
        query_embedding = self.text_embedder.embed_single_text(query_text)
        image_embeddings = np.stack(list(self.image_data.values()))  # (N, D)
        return np.dot(image_embeddings, query_embedding)

    def search(
        self,
        query_text: str,
        top_k: int = DEFAULT_TOP_K,
        return_scores: bool = False,
    ) -> List[Dict]:
        """Search for images matching the query text.

        Args:
            query_text: Text description to search for.
            top_k: Maximum number of results to return. Values of zero or
                less yield an empty list.
            return_scores: Whether to include similarity scores.

        Returns:
            List of dictionaries, best match first, each with a 'path' key
            and, when *return_scores* is true, a float 'score' key.
        """
        logger.info("Searching for: '%s'", query_text)

        # A negative value would otherwise slice from the end of the
        # ranking and return all but the last |top_k| entries.
        if top_k <= 0:
            return []

        image_paths = list(self.image_data)
        similarities = self._compute_similarities(query_text)

        # Negate so that argsort's ascending order ranks the highest
        # similarity first.
        top_k_indices = np.argsort(-similarities)[:top_k]

        results = []
        for idx in top_k_indices:
            result = {'path': image_paths[idx]}
            if return_scores:
                result['score'] = float(similarities[idx])
            results.append(result)
        return results

    def batch_search(
        self,
        query_texts: List[str],
        top_k: int = DEFAULT_TOP_K,
    ) -> Dict[str, List[Dict]]:
        """Run ``search`` (with scores) for each query.

        Args:
            query_texts: List of text descriptions.
            top_k: Number of results per query.

        Returns:
            Dictionary mapping query text to its list of results. Repeated
            query strings collapse into a single key.
        """
        logger.info("Batch searching %s queries", len(query_texts))
        return {
            query: self.search(query, top_k=top_k, return_scores=True)
            for query in query_texts
        }

    def get_search_stats(self, query_text: str) -> Dict:
        """Summarize the similarity of one query against all images.

        Args:
            query_text: Query to analyze.

        Returns:
            Dictionary with float 'mean_similarity', 'max_similarity',
            'min_similarity', 'std_similarity' and 'median_similarity'.
        """
        similarities = self._compute_similarities(query_text)
        return {
            'mean_similarity': float(np.mean(similarities)),
            'max_similarity': float(np.max(similarities)),
            'min_similarity': float(np.min(similarities)),
            'std_similarity': float(np.std(similarities)),
            'median_similarity': float(np.median(similarities)),
        }

    def get_memory_usage(self):
        """Return whatever the text embedder's ``get_memory_usage`` reports."""
        return self.text_embedder.get_memory_usage()

    def clear_cache(self):
        """Ask the text embedder to clear its cache."""
        self.text_embedder.clear_cache()


def print_results(results: List[Dict], query: str, show_scores: bool = True):
    """Print a numbered list of search results to stdout.

    Args:
        results: Result dictionaries with a 'path' key and optionally a
            'score' key.
        query: The query text, shown in the header.
        show_scores: Print the score column for results that carry one.
    """
    print(f"\nTop {len(results)} results for: '{query}'")
    print("-" * 80)

    for i, result in enumerate(results, 1):
        path = result['path']
        if show_scores and 'score' in result:
            score = result['score']
            print(f"{i:2d}. {path:<60} score={score:>7.4f}")
        else:
            print(f"{i:2d}. {path}")

    print("-" * 80)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the search script."""
    parser = argparse.ArgumentParser(
        description="Text-to-Image Search using CLIP embeddings"
    )

    # Input/Output paths
    parser.add_argument(
        "--image-embeddings",
        type=str,
        default=str(DEFAULT_IMG_EMB_PATH),
        help="Path to image embeddings pickle file",
    )

    # Query options
    parser.add_argument(
        "--query",
        type=str,
        help="Single text query to search for",
    )
    parser.add_argument(
        "--queries-file",
        type=str,
        help="File with one query per line",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Interactive mode for multiple queries",
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=DEFAULT_TOP_K,
        help="Number of top results to return",
    )
    parser.add_argument(
        "--text-model",
        type=str,
        default=DEFAULT_TEXT_MODEL,
        help="CLIP text model to use",
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=DEFAULT_EMBEDDING_DIM,
        help="Embedding dimension",
    )

    # Output options
    parser.add_argument(
        "--output",
        type=str,
        help="Save results to JSON file",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show similarity statistics",
    )
    return parser


def _print_stats_block(stats: Dict) -> None:
    """Print every similarity statistic, one per line."""
    print("\nSimilarity Statistics:")
    for key, value in stats.items():
        print(f" {key}: {value:.4f}")


def _run_interactive(
    searcher: TextToImageSearcher,
    top_k: int,
    show_stats: bool,
    collected: Dict[str, List[Dict]],
) -> None:
    """Prompt for queries until the user quits; record results in *collected*."""
    print("\nInteractive Text-to-Image Search")
    print("Enter text queries (or 'quit' to exit):")
    print("-" * 40)

    while True:
        try:
            query = input("\nQuery: ").strip()

            if query.lower() in ['quit', 'exit', 'q']:
                break
            if not query:
                continue

            results = searcher.search(query, top_k=top_k, return_scores=True)
            print_results(results, query)
            collected[query] = results

            if show_stats:
                stats = searcher.get_search_stats(query)
                print(f"\nStats: mean={stats['mean_similarity']:.3f}, "
                      f"max={stats['max_similarity']:.3f}, "
                      f"std={stats['std_similarity']:.3f}")

        except (KeyboardInterrupt, EOFError):
            # EOFError (Ctrl-D or exhausted piped stdin) must end the loop:
            # input() would keep raising it, so merely logging it would
            # spin forever.
            print("\nExiting...")
            break
        except Exception as e:
            # Keep the session alive after a failed query.
            logger.error("Error processing query: %s", e)


def _save_results(results: Dict[str, List[Dict]], output_path: str) -> None:
    """Write *results* to *output_path* as JSON; log instead of raising."""
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        logger.info("Results saved to %s", output_path)
    except Exception as e:
        logger.error("Failed to save results: %s", e)


def main():
    """Command-line entry point.

    Runs one of four modes, chosen in this order: ``--query``,
    ``--queries-file``, ``--interactive``, or a built-in demo query when
    none is given. With ``--output``, the results gathered in whichever
    mode ran are written to a JSON file.
    """
    args = _build_arg_parser().parse_args()

    # Initialize searcher
    try:
        searcher = TextToImageSearcher(
            image_embeddings_path=args.image_embeddings,
            text_model_name=args.text_model,
            embedding_dim=args.embedding_dim,
        )
    except Exception as e:
        logger.error("Failed to initialize searcher: %s", e)
        return

    # Show memory usage
    mem = searcher.get_memory_usage()
    logger.info(
        "GPU Memory - Allocated: %.2fGB, Reserved: %.2fGB",
        mem['allocated'],
        mem['reserved'],
    )

    # Results from every mode land here so --output works for all of them.
    collected: Dict[str, List[Dict]] = {}

    # Handle different query modes
    if args.query:
        # Single query
        results = searcher.search(args.query, top_k=args.top_k, return_scores=True)
        print_results(results, args.query)
        collected[args.query] = results

        if args.stats:
            _print_stats_block(searcher.get_search_stats(args.query))

    elif args.queries_file:
        # Batch queries from file
        try:
            with open(args.queries_file, 'r', encoding='utf-8') as f:
                queries = [line.strip() for line in f if line.strip()]

            collected = searcher.batch_search(queries, top_k=args.top_k)
            for query, results in collected.items():
                print_results(results, query)
        except Exception as e:
            logger.error("Failed to process queries file: %s", e)
            return

    elif args.interactive:
        # Interactive mode
        _run_interactive(searcher, args.top_k, args.stats, collected)

    else:
        # Demo mode with sample queries
        demo_queries = [
            "a cat sleeping on a bed",
        ]

        print("Demo mode: Running sample queries...")
        collected = searcher.batch_search(demo_queries, top_k=5)
        for query, results in collected.items():
            print_results(results, query)

    # Save results if requested
    if args.output:
        _save_results(collected, args.output)


if __name__ == "__main__":
    main()