"""Find images similar to a query image using precomputed embeddings.

The embeddings are read from a pickle file mapping image path -> embedding
vector. Similarity is the dot product between the (re-normalized) query
vector and every stored vector, which equals cosine similarity when the
stored vectors are L2-normalized.
"""

import argparse
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

try:
    # Pillow is not used by anything in this module; the import is kept so the
    # name stays available, but it must not be a hard requirement for a tool
    # that only does NumPy math on precomputed vectors.
    from PIL import Image
except ImportError:  # pragma: no cover - depends on the environment
    Image = None

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
EMBED_DIR = Path(__file__).parent.parent / "embeddings"
IMAGE_EMB_PATH = EMBED_DIR / "image_embeddings.pkl"


class ImageSimilarityFinder:
    """
    Find similar images using precomputed embeddings from image_embeddings.pkl.

    Uses cosine similarity to rank images by similarity.
    """

    def __init__(self, embeddings_path: Union[str, Path] = IMAGE_EMB_PATH):
        """
        Initialize the similarity finder with precomputed embeddings.

        Args:
            embeddings_path: Path to the pickle file containing image embeddings

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: If the file cannot be unpickled or is not a dict.
        """
        self.embeddings_path = Path(embeddings_path)
        self.embeddings = self._load_embeddings()
        # Keys and rows are taken from the same dict, so index i of
        # image_paths corresponds to row i of embedding_matrix.
        self.image_paths = list(self.embeddings.keys())
        self.embedding_matrix = np.array(list(self.embeddings.values()))

        logger.info("Loaded %d image embeddings", len(self.embeddings))

    def _load_embeddings(self) -> Dict[str, np.ndarray]:
        """
        Load embeddings from the pickle file at ``self.embeddings_path``.

        Returns:
            The unpickled mapping of image path -> embedding vector.

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: If unpickling fails or the payload is not a dict.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point this at embeddings files from a trusted source.
        try:
            with open(self.embeddings_path, "rb") as fp:
                embeddings = pickle.load(fp)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Embeddings file not found: {self.embeddings_path}"
            ) from err
        except Exception as err:
            raise RuntimeError(f"Failed to load embeddings: {err}") from err

        # Fail here with a clear message rather than later with an
        # AttributeError when __init__ calls .keys() on the payload.
        if not isinstance(embeddings, dict):
            raise RuntimeError(
                "Failed to load embeddings: expected a dict of path -> vector, "
                f"got {type(embeddings).__name__}"
            )
        return embeddings

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Calculate cosine similarity between two vectors.

        Since embeddings are already L2 normalized, this is just dot product.
        """
        return np.dot(vec1, vec2)

    def find_similar_by_path(
        self,
        query_image_path: Union[str, Path],
        k: int = 5,
        exclude_self: bool = True
    ) -> List[Tuple[str, float]]:
        """
        Find k most similar images to a query image by its path.

        Args:
            query_image_path: Path to the query image
            k: Number of similar images to return
            exclude_self: Whether to exclude the query image itself from results

        Returns:
            List of tuples (image_path, similarity_score) sorted by similarity
            (highest first)

        Raises:
            ValueError: If the query path is not a key of the loaded
                embeddings, or its stored embedding has zero norm.
        """
        # Normalize through Path so "dir//img.jpg" style inputs match the
        # string form used for lookup.
        query_path = str(Path(query_image_path))

        if query_path not in self.embeddings:
            raise ValueError(f"Query image not found in embeddings: {query_path}")

        query_embedding = self.embeddings[query_path]
        return self._find_similar_by_embedding(
            query_embedding,
            k,
            exclude_path=query_path if exclude_self else None,
        )

    def find_similar_by_embedding(
        self,
        query_embedding: np.ndarray,
        k: int = 5
    ) -> List[Tuple[str, float]]:
        """
        Find k most similar images to a query embedding.

        Args:
            query_embedding: Query image embedding vector
            k: Number of similar images to return

        Returns:
            List of tuples (image_path, similarity_score) sorted by similarity
            (highest first)

        Raises:
            ValueError: If the query embedding has zero or non-finite norm.
        """
        return self._find_similar_by_embedding(query_embedding, k)

    def _find_similar_by_embedding(
        self,
        query_embedding: np.ndarray,
        k: int,
        exclude_path: Optional[str] = None
    ) -> List[Tuple[str, float]]:
        """
        Internal method to find similar images by embedding.

        Args:
            query_embedding: Query embedding vector
            k: Number of results to return
            exclude_path: Path to exclude from results (typically the query
                image itself)

        Returns:
            List of at most k tuples (image_path, similarity_score) sorted by
            similarity, highest first. Empty when k <= 0 or when no embeddings
            are loaded.

        Raises:
            ValueError: If the query embedding has zero or non-finite norm, so
                it cannot be normalized.
        """
        # Nothing requested or nothing to search: return early instead of
        # emitting one stray result (k <= 0) or failing in np.dot on an
        # empty matrix.
        if k <= 0 or not self.image_paths:
            return []

        query_embedding = np.asarray(query_embedding)
        norm = np.linalg.norm(query_embedding)
        if norm == 0 or not np.isfinite(norm):
            # Dividing by a zero norm would yield NaN similarities and a
            # meaningless ranking.
            raise ValueError("Query embedding must have a non-zero, finite norm")

        # Ensure query embedding is normalized (stored vectors are expected to
        # be normalized already; they are not re-normalized here).
        query_embedding = query_embedding / norm

        # Compute similarities with all embeddings
        similarities = np.dot(self.embedding_matrix, query_embedding)

        # Get indices sorted by similarity (descending)
        sorted_indices = np.argsort(similarities)[::-1]

        # Collect results, excluding the query image if specified
        results: List[Tuple[str, float]] = []
        for idx in sorted_indices:
            image_path = self.image_paths[idx]

            # Skip if this is the path to exclude
            if exclude_path and image_path == exclude_path:
                continue

            results.append((image_path, float(similarities[idx])))

            # Stop when we have k results
            if len(results) >= k:
                break

        return results

    def find_similar_images_batch(
        self,
        query_paths: List[Union[str, Path]],
        k: int = 5
    ) -> Dict[str, List[Tuple[str, float]]]:
        """
        Find similar images for multiple query images at once.

        Queries that raise ValueError (e.g. a path missing from the loaded
        embeddings) are logged and mapped to an empty list rather than
        aborting the whole batch.

        Args:
            query_paths: List of paths to query images
            k: Number of similar images to return for each query

        Returns:
            Dictionary mapping query_path -> list of (similar_path, similarity_score)
        """
        results: Dict[str, List[Tuple[str, float]]] = {}
        for query_path in query_paths:
            try:
                results[str(query_path)] = self.find_similar_by_path(query_path, k)
            except ValueError as e:
                logger.warning("Skipping %s: %s", query_path, e)
                results[str(query_path)] = []
        return results

    def get_image_info(self) -> Dict[str, Any]:
        """
        Get information about loaded embeddings.

        Returns:
            Dict with ``total_images`` (number of entries),
            ``embedding_dimension`` (second axis of the embedding matrix, or 0
            when the matrix is empty or not 2-D) and ``sample_paths`` (up to
            the first five image paths).
        """
        matrix = self.embedding_matrix
        has_rows = matrix.ndim == 2 and len(matrix) > 0
        return {
            "total_images": len(self.embeddings),
            "embedding_dimension": matrix.shape[1] if has_rows else 0,
            "sample_paths": self.image_paths[:5] if self.image_paths else [],
        }


def main():
    """Command-line entry point: print the images most similar to --query."""
    parser = argparse.ArgumentParser(
        description="Find similar images using precomputed embeddings"
    )
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Path to the query image"
    )
    parser.add_argument(
        "--embeddings",
        type=str,
        default=str(IMAGE_EMB_PATH),
        help="Path to the embeddings pickle file"
    )
    parser.add_argument(
        "--k",
        type=int,
        default=5,
        help="Number of similar images to return (default: 5)"
    )
    parser.add_argument(
        "--include-self",
        action="store_true",
        help="Include the query image in results"
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Show information about loaded embeddings"
    )

    args = parser.parse_args()

    # Initialize similarity finder
    try:
        finder = ImageSimilarityFinder(args.embeddings)
    except Exception as e:
        # Top-level CLI boundary: report and exit instead of a traceback.
        logger.error("Failed to initialize similarity finder: %s", e)
        logger.error("Make sure you're providing the path to the embeddings pickle file, not an image file")
        logger.error("Expected: --embeddings /path/to/image_embeddings.pkl")
        return

    # Show info if requested
    if args.info:
        info = finder.get_image_info()
        print("Embeddings Info:")
        print(f" Total images: {info['total_images']}")
        print(f" Embedding dimension: {info['embedding_dimension']}")
        print(f" Sample paths: {info['sample_paths']}")
        print()

    # Find similar images
    try:
        similar_images = finder.find_similar_by_path(
            args.query,
            k=args.k,
            exclude_self=not args.include_self
        )

        print(f"Top {len(similar_images)} similar images to '{args.query}':")
        print("-" * 80)

        for i, (image_path, similarity) in enumerate(similar_images, 1):
            print(f"{i:2d}. {Path(image_path).name}")
            print(f" Path: {image_path}")
            print(f" Similarity: {similarity:.4f}")
            print()

    except Exception as e:
        logger.error("Failed to find similar images: %s", e)

        # Check if query image exists in embeddings
        if "not found in embeddings" in str(e):
            # Help the user pick a valid key by listing a few known paths.
            logger.error("Available images in embeddings:")
            info = finder.get_image_info()
            for i, path in enumerate(info['sample_paths']):
                logger.error(" %d. %s", i + 1, path)
            if len(finder.image_paths) > 5:
                logger.error(" ... and %d more images", len(finder.image_paths) - 5)


if __name__ == "__main__":
    main()