import argparse
import json
import logging
import pickle
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from embedder import ImageEmbedder

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
DEFAULT_TOP_K = 10

# File extensions (compared case-insensitively) that --query-dir treats as images.
IMAGE_EXTENSIONS = frozenset(
    {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".tif"}
)


class ImageSimilaritySearcher:
    """
    Image-to-image similarity search using unified embeddings dictionary.

    The database is a pickled ``{image_path: embedding}`` dict. Query images
    are embedded with ``ImageEmbedder`` and ranked against the database by
    cosine similarity.
    """

    def __init__(
        self,
        image_embeddings_path: Path,
        vision_model_name: str = DEFAULT_VISION_MODEL,
        embedding_dim: int = 4096,
        llava_ckpt_path: Optional[str] = None,
    ):
        """
        Args:
            image_embeddings_path: Pickle file holding the embeddings dict.
            vision_model_name: Vision model name passed to ``ImageEmbedder``.
            embedding_dim: Expected last dimension of every stored embedding;
                also passed to ``ImageEmbedder`` as ``proj_out_dim``.
            llava_ckpt_path: Optional checkpoint passed to ``ImageEmbedder``.

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: If the embeddings file is invalid.
        """
        self.image_embeddings_path = Path(image_embeddings_path)
        self.embedding_dim = embedding_dim
        # Cache of (source dict, paths, (N, D) matrix); see _get_database().
        self._db_cache: Optional[Tuple[Dict[str, np.ndarray], List[str], np.ndarray]] = None

        # Load image embeddings
        logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
        self.image_data = self._load_image_embeddings()

        # Initialize image embedder for query images
        logger.info(f"Initializing image embedder: {vision_model_name}")
        self.image_embedder = ImageEmbedder(
            vision_model_name=vision_model_name,
            proj_out_dim=embedding_dim,
            llava_ckpt_path=llava_ckpt_path,
        )

        logger.info(f"ImageSimilaritySearcher ready with {len(self.image_data)} images")

    def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
        """
        Load and validate the embeddings dictionary from the pickle file.

        Only the first entry's last dimension is checked against
        ``embedding_dim``.

        Returns:
            Mapping of image path -> embedding array.

        Raises:
            FileNotFoundError: If the pickle file does not exist.
            RuntimeError: If the file cannot be unpickled, is not a dict, or
                has the wrong embedding dimension (original cause is chained).
        """
        try:
            # SECURITY: pickle.load can execute arbitrary code. Only load
            # embedding files that come from a trusted source.
            with open(self.image_embeddings_path, "rb") as f:
                data = pickle.load(f)

            if not isinstance(data, dict):
                raise ValueError("Image embeddings file should contain a dictionary")

            # Verify embedding dimensions
            if data:
                sample_emb = np.asarray(next(iter(data.values())))
                if sample_emb.shape[-1] != self.embedding_dim:
                    raise ValueError(
                        f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D"
                    )
                logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
            else:
                logger.warning("Empty embeddings dictionary loaded")

            return data

        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Image embeddings file not found: {self.image_embeddings_path}"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to load image embeddings: {e}") from e

    def _get_database(self) -> Tuple[List[str], np.ndarray]:
        """
        Return ``(paths, matrix)``, where ``matrix[i]`` is the flattened
        embedding of ``paths[i]``.

        The result is cached. It is rebuilt only when ``self.image_data`` is
        replaced or changes size, so repeated queries (e.g. ``batch_search``)
        do not re-stack the whole database each time.
        NOTE: in-place value replacement that keeps the size is not detected.
        """
        data = self.image_data
        cache = getattr(self, "_db_cache", None)
        if cache is None or cache[0] is not data or len(cache[1]) != len(data):
            paths = list(data.keys())
            if paths:
                # Flatten so that both (D,) and (1, D) storage yield an (N, D) matrix.
                matrix = np.stack([np.asarray(v).reshape(-1) for v in data.values()])
            else:
                matrix = np.empty((0, self.embedding_dim))
            cache = (data, paths, matrix)
            self._db_cache = cache
        return cache[1], cache[2]

    def _embed_image_file(self, image_path: str) -> np.ndarray:
        """Open ``image_path``, convert it to RGB and return its embedding as a numpy array."""
        # The context manager closes the file handle. convert() returns an
        # independent, fully loaded copy.
        with Image.open(image_path) as img:
            rgb_img = img.convert("RGB")
        return self.image_embedder.image_to_embedding(rgb_img).cpu().numpy()

    def search_by_image_path(
        self,
        query_image_path: str,
        top_k: int = DEFAULT_TOP_K,
        exclude_self: bool = True,
        return_scores: bool = False,
    ) -> List[Dict]:
        """
        Search for similar images given a query image path.

        Args:
            query_image_path: Path to query image
            top_k: Number of top results to return
            exclude_self: Whether to exclude the query image from results.
                Matching is an exact string comparison against database keys.
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys

        Raises:
            RuntimeError: If the query image cannot be opened or embedded.
        """
        logger.info(f"Searching for images similar to: {query_image_path}")

        try:
            query_embedding = self._embed_image_file(query_image_path)
        except Exception as e:
            raise RuntimeError(f"Failed to process query image {query_image_path}: {e}") from e

        return self._search_by_embedding(
            query_embedding,
            query_image_path if exclude_self else None,
            top_k,
            return_scores,
        )

    def search_by_image(
        self,
        query_image: Image.Image,
        top_k: int = DEFAULT_TOP_K,
        return_scores: bool = False,
    ) -> List[Dict]:
        """
        Search for similar images given a PIL Image.

        Args:
            query_image: PIL Image object (passed to the embedder as-is)
            top_k: Number of top results to return
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys
        """
        logger.info("Searching for similar images using provided image")

        query_embedding = self.image_embedder.image_to_embedding(query_image).cpu().numpy()
        return self._search_by_embedding(query_embedding, None, top_k, return_scores)

    def _search_by_embedding(
        self,
        query_embedding: np.ndarray,
        exclude_path: Optional[str],
        top_k: int,
        return_scores: bool,
    ) -> List[Dict]:
        """
        Rank database images by cosine similarity to ``query_embedding``.

        Args:
            query_embedding: Query vector; it is flattened to a single row.
            exclude_path: Database key to drop from the results (ignored if falsy).
            top_k: Maximum number of results. Values <= 0 yield no results.
            return_scores: Whether to include the cosine similarity as 'score'.

        Returns:
            Up to ``top_k`` dicts, highest similarity first. An empty
            database yields ``[]``.
        """
        image_paths, image_embeddings = self._get_database()
        # Guard: a negative top_k would otherwise slice [:-k] and return
        # nearly the whole database.
        if not image_paths or top_k <= 0:
            return []

        similarities = cosine_similarity(
            query_embedding.reshape(1, -1), image_embeddings
        ).flatten()

        # Stable sort: ties keep database insertion order, so results are deterministic.
        ranked = np.argsort(-similarities, kind="stable")

        results: List[Dict] = []
        for idx in ranked:
            path = image_paths[idx]
            if exclude_path and path == exclude_path:
                continue
            result = {"path": path}
            if return_scores:
                result["score"] = float(similarities[idx])
            results.append(result)
            if len(results) >= top_k:
                break
        return results

    def batch_search(
        self,
        query_image_paths: List[str],
        top_k: int = DEFAULT_TOP_K,
        exclude_self: bool = True,
    ) -> Dict[str, List[Dict]]:
        """
        Search for multiple query images at once.

        Args:
            query_image_paths: List of paths to query images
            top_k: Number of results per query
            exclude_self: Whether to exclude query images from results

        Returns:
            Dictionary mapping query path to its list of results (always with
            scores). A query that fails is logged and maps to ``[]``.
        """
        logger.info(f"Batch searching {len(query_image_paths)} images")

        results = {}
        for query_path in query_image_paths:
            try:
                results[query_path] = self.search_by_image_path(
                    query_path,
                    top_k=top_k,
                    exclude_self=exclude_self,
                    return_scores=True,
                )
            except Exception as e:
                logger.error(f"Failed to search for {query_path}: {e}")
                results[query_path] = []

        return results

    def get_similarity_stats(self, query_image_path: str) -> Dict:
        """
        Summarize the cosine similarities between a query image and every
        database image (the query itself is included if it is in the database).

        Args:
            query_image_path: Path to query image

        Returns:
            Dict with mean/max/min/std/median similarity as floats.

        Raises:
            RuntimeError: If the query image cannot be opened or embedded.
            ValueError: If the database is empty.
        """
        try:
            query_embedding = self._embed_image_file(query_image_path)
        except Exception as e:
            raise RuntimeError(f"Failed to process query image {query_image_path}: {e}") from e

        _, image_embeddings = self._get_database()
        if image_embeddings.shape[0] == 0:
            raise ValueError("Cannot compute similarity statistics: no image embeddings loaded")

        similarities = cosine_similarity(
            query_embedding.reshape(1, -1), image_embeddings
        ).flatten()

        return {
            "mean_similarity": float(np.mean(similarities)),
            "max_similarity": float(np.max(similarities)),
            "min_similarity": float(np.min(similarities)),
            "std_similarity": float(np.std(similarities)),
            "median_similarity": float(np.median(similarities)),
        }

    def get_memory_usage(self):
        """Get memory usage information (delegates to the image embedder)."""
        return self.image_embedder.get_memory_usage()


def print_results(results: List[Dict], query: str, show_scores: bool = True):
    """Pretty print search results"""
    print(f"\nTop {len(results)} similar images to: {Path(query).name}")
    print(f"Query: {query}")
    print("-" * 100)

    for i, result in enumerate(results, 1):
        path = result["path"]
        if show_scores and "score" in result:
            score = result["score"]
            print(f"{i:2d}. {path:<80} score={score:>7.4f}")
        else:
            print(f"{i:2d}. {path}")

    print("-" * 100)


def _print_batch_results(batch_results: Dict[str, List[Dict]]) -> None:
    """Print every query's results, skipping queries that produced none."""
    for query_path, results in batch_results.items():
        if results:  # Only show if we got results
            print_results(results, query_path)


def _find_images(directory: Path) -> List[str]:
    """Return sorted paths of regular files directly in ``directory`` with an image extension."""
    # A case-insensitive suffix check avoids the duplicates that globbing both
    # "*.jpg" and "*.JPG" produces on case-insensitive filesystems.
    return sorted(
        str(p)
        for p in directory.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Image-to-Image Similarity Search")

    # Input/Output paths
    parser.add_argument(
        "--image-embeddings",
        type=str,
        default=str(DEFAULT_IMG_EMB_PATH),
        help="Path to image embeddings pickle file",
    )

    # Query options
    parser.add_argument("--query-image", type=str, help="Path to query image")
    parser.add_argument("--query-images-file", type=str, help="File with one image path per line")
    parser.add_argument("--query-dir", type=str, help="Directory containing query images")

    # Model parameters
    parser.add_argument(
        "--vision-model", type=str, default=DEFAULT_VISION_MODEL, help="CLIP vision model to use"
    )
    parser.add_argument("--embedding-dim", type=int, default=4096, help="Embedding dimension")
    parser.add_argument(
        "--llava-ckpt", type=str, help="Path to LLaVA checkpoint for projection weights"
    )

    # Search parameters
    parser.add_argument(
        "--top-k", type=int, default=DEFAULT_TOP_K, help="Number of top results to return"
    )
    parser.add_argument(
        "--include-self", action="store_true", help="Include query image in results"
    )

    # Output options
    parser.add_argument("--output", type=str, help="Save results to JSON file")
    parser.add_argument("--stats", action="store_true", help="Show similarity statistics")
    return parser


def main() -> int:
    """Command-line entry point. Returns a process exit code (0 on success)."""
    args = _build_arg_parser().parse_args()

    # Initialize searcher
    try:
        searcher = ImageSimilaritySearcher(
            image_embeddings_path=args.image_embeddings,
            vision_model_name=args.vision_model,
            embedding_dim=args.embedding_dim,
            llava_ckpt_path=args.llava_ckpt,
        )
    except Exception as e:
        logger.error(f"Failed to initialize searcher: {e}")
        return 1

    # Show memory usage
    mem = searcher.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Query path -> results. Filled by every query mode so --output works for all of them.
    all_results: Dict[str, List[Dict]] = {}

    if args.query_image:
        # Single query image
        try:
            results = searcher.search_by_image_path(
                args.query_image,
                top_k=args.top_k,
                exclude_self=not args.include_self,
                return_scores=True,
            )
        except RuntimeError as e:
            logger.error(str(e))
            return 1
        print_results(results, args.query_image)
        all_results[args.query_image] = results

        if args.stats:
            stats = searcher.get_similarity_stats(args.query_image)
            print("\nSimilarity Statistics:")
            for key, value in stats.items():
                print(f"  {key}: {value:.4f}")

    elif args.query_images_file:
        # Batch queries from file
        try:
            with open(args.query_images_file, "r", encoding="utf-8") as f:
                query_paths = [line.strip() for line in f if line.strip()]
        except (OSError, UnicodeDecodeError) as e:
            logger.error(f"Failed to process queries file: {e}")
            return 1

        all_results = searcher.batch_search(
            query_paths, top_k=args.top_k, exclude_self=not args.include_self
        )
        _print_batch_results(all_results)

    elif args.query_dir:
        # Query all images in directory
        query_dir = Path(args.query_dir)
        if not query_dir.is_dir():
            logger.error(f"Query directory does not exist: {query_dir}")
            return 1

        query_paths = _find_images(query_dir)
        if not query_paths:
            logger.error(f"No images found in {query_dir}")
            return 1

        logger.info(f"Found {len(query_paths)} images in {query_dir}")
        all_results = searcher.batch_search(
            query_paths, top_k=args.top_k, exclude_self=not args.include_self
        )
        _print_batch_results(all_results)

    else:
        print("Please specify a query image, query images file, or query directory")
        print("Use --help for more information")
        return 1

    # Save results if requested
    if args.output:
        try:
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(all_results, f, indent=2, ensure_ascii=False)
            logger.info(f"Results saved to {args.output}")
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Failed to save results: {e}")
            return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())