262 lines
8.6 KiB
Python
262 lines
8.6 KiB
Python
import pickle
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple, Union, Optional
|
|
import argparse
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
# Logging setup: root logger configured at INFO level, plus a logger named
# after this module for all messages emitted below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default locations, resolved relative to this file: the "embeddings"
# directory sits next to this file's parent directory, and the image
# embeddings pickle lives inside it.
EMBED_DIR = Path(__file__).parent.parent.joinpath("embeddings")
IMAGE_EMB_PATH = EMBED_DIR.joinpath("image_embeddings.pkl")
|
|
|
|
|
|
class ImageSimilarityFinder:
    """
    Find similar images using precomputed embeddings from image_embeddings.pkl.

    Images are ranked by the dot product between the normalized query
    embedding and every stored embedding. That equals cosine similarity only
    if the stored embeddings are L2-normalized, which this class assumes and
    does not verify.

    Attributes:
        embeddings_path: Path of the pickle file the embeddings were read from.
        embeddings: Mapping of image path -> embedding vector, as loaded.
        image_paths: Keys of ``embeddings`` in iteration order; row ``i`` of
            ``embedding_matrix`` corresponds to ``image_paths[i]``.
        embedding_matrix: The embedding vectors stacked with ``np.array``.
    """

    def __init__(self, embeddings_path: Union[str, Path] = IMAGE_EMB_PATH):
        """
        Initialize the similarity finder with precomputed embeddings.

        Args:
            embeddings_path: Path to the pickle file containing image embeddings

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: If the file cannot be loaded or is not a dict.
        """
        self.embeddings_path = Path(embeddings_path)
        self.embeddings = self._load_embeddings()
        # Keys and values are read from the same dict, so index i in
        # image_paths lines up with row i in embedding_matrix.
        self.image_paths = list(self.embeddings)
        self.embedding_matrix = np.array(list(self.embeddings.values()))

        logger.info("Loaded %d image embeddings", len(self.embeddings))

    def _load_embeddings(self) -> Dict[str, np.ndarray]:
        """
        Load embeddings from the pickle file at ``self.embeddings_path``.

        Returns:
            The unpickled mapping of image path -> embedding vector.

        Raises:
            FileNotFoundError: If the file does not exist.
            RuntimeError: If reading/unpickling fails for any other reason,
                or if the unpickled object is not a dict.
        """
        try:
            with open(self.embeddings_path, "rb") as fp:
                # SECURITY: pickle.load can execute arbitrary code embedded in
                # the file. Only point this at embeddings files you trust.
                embeddings = pickle.load(fp)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Embeddings file not found: {self.embeddings_path}"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to load embeddings: {e}") from e

        # Fail here with a clear message instead of with an AttributeError
        # later when the caller asks for .keys()/.values().
        if not isinstance(embeddings, dict):
            raise RuntimeError(
                "Failed to load embeddings: expected a dict of "
                f"image path -> embedding, got {type(embeddings).__name__}"
            )
        return embeddings

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Calculate cosine similarity between two vectors.

        Computed as a plain dot product, so the result is the cosine
        similarity only when both vectors are already L2-normalized.
        """
        return np.dot(vec1, vec2)

    def find_similar_by_path(
        self,
        query_image_path: Union[str, Path],
        k: int = 5,
        exclude_self: bool = True
    ) -> List[Tuple[str, float]]:
        """
        Find k most similar images to a query image by its path.

        The lookup key is ``str(Path(query_image_path))``; it must match a
        key of the loaded embeddings exactly.

        Args:
            query_image_path: Path to the query image
            k: Number of similar images to return
            exclude_self: Whether to exclude the query image itself from results

        Returns:
            List of tuples (image_path, similarity_score) sorted by similarity (highest first)

        Raises:
            ValueError: If the query path has no stored embedding.
        """
        query_path = str(Path(query_image_path))

        try:
            query_embedding = self.embeddings[query_path]
        except KeyError:
            raise ValueError(
                f"Query image not found in embeddings: {query_path}"
            ) from None

        return self._find_similar_by_embedding(
            query_embedding,
            k,
            exclude_path=query_path if exclude_self else None,
        )

    def find_similar_by_embedding(
        self,
        query_embedding: np.ndarray,
        k: int = 5
    ) -> List[Tuple[str, float]]:
        """
        Find k most similar images to a query embedding.

        Args:
            query_embedding: Query image embedding vector
            k: Number of similar images to return

        Returns:
            List of tuples (image_path, similarity_score) sorted by similarity (highest first)

        Raises:
            ValueError: If the query embedding has zero or non-finite norm.
        """
        return self._find_similar_by_embedding(query_embedding, k)

    def _find_similar_by_embedding(
        self,
        query_embedding: np.ndarray,
        k: int,
        exclude_path: Optional[str] = None
    ) -> List[Tuple[str, float]]:
        """
        Internal method to find similar images by embedding.

        Args:
            query_embedding: Query embedding vector
            k: Number of results to return; values <= 0 yield an empty list
            exclude_path: Path to exclude from results (typically the query image itself)

        Returns:
            List of at most k tuples (image_path, similarity_score), sorted by
            similarity (highest first). Empty if no embeddings are loaded.

        Raises:
            ValueError: If the query embedding has zero or non-finite norm,
                since it cannot be normalized.
        """
        # Nothing requested or nothing to search: avoid returning a stray
        # result for k <= 0 and avoid a shape error on an empty matrix.
        if k <= 0 or not self.image_paths:
            return []

        # Normalize the query so its scale does not affect the scores.
        query = np.asarray(query_embedding)
        norm = np.linalg.norm(query)
        if not np.isfinite(norm) or norm == 0:
            raise ValueError(
                "Query embedding must have a finite, non-zero norm"
            )
        query = query / norm

        # Similarity of the query against every stored embedding at once.
        similarities = np.dot(self.embedding_matrix, query)

        # Indices ordered from most to least similar.
        ranked_indices = np.argsort(similarities)[::-1]

        results = []
        for idx in ranked_indices:
            image_path = self.image_paths[idx]

            # Drop the excluded entry (usually the query image itself).
            if exclude_path and image_path == exclude_path:
                continue

            results.append((image_path, float(similarities[idx])))

            if len(results) >= k:
                break

        return results

    def find_similar_images_batch(
        self,
        query_paths: List[Union[str, Path]],
        k: int = 5
    ) -> Dict[str, List[Tuple[str, float]]]:
        """
        Find similar images for multiple query images at once.

        Queries that raise ValueError (for example, a path with no stored
        embedding) are logged as warnings and mapped to an empty list.

        Args:
            query_paths: List of paths to query images
            k: Number of similar images to return for each query

        Returns:
            Dictionary mapping query_path -> list of (similar_path, similarity_score)
        """
        results = {}
        for query_path in query_paths:
            key = str(query_path)
            try:
                results[key] = self.find_similar_by_path(query_path, k)
            except ValueError as e:
                logger.warning("Skipping %s: %s", query_path, e)
                results[key] = []

        return results

    def get_image_info(self) -> Dict[str, object]:
        """
        Get information about loaded embeddings.

        Returns:
            Dictionary with keys ``total_images`` (number of embeddings),
            ``embedding_dimension`` (second axis of the embedding matrix, or 0
            when the matrix is empty or not 2-D) and ``sample_paths`` (up to
            the first five image paths).
        """
        matrix = self.embedding_matrix
        # shape[1] only exists for a non-empty 2-D matrix.
        has_dimension = matrix.ndim > 1 and len(matrix) > 0
        return {
            "total_images": len(self.embeddings),
            "embedding_dimension": matrix.shape[1] if has_dimension else 0,
            "sample_paths": self.image_paths[:5],
        }
|
|
|
|
|
|
def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point: print the images most similar to a query image.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read the process command line.

    Returns:
        0 on success; 1 if the embeddings could not be loaded or the
        similarity search failed. Invalid arguments make argparse exit
        with status 2 via SystemExit.
    """
    parser = argparse.ArgumentParser(
        description="Find similar images using precomputed embeddings"
    )
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Path to the query image"
    )
    parser.add_argument(
        "--embeddings",
        type=str,
        default=str(IMAGE_EMB_PATH),
        help="Path to the embeddings pickle file"
    )
    parser.add_argument(
        "--k",
        type=int,
        default=5,
        help="Number of similar images to return (default: 5)"
    )
    parser.add_argument(
        "--include-self",
        action="store_true",
        help="Include the query image in results"
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Show information about loaded embeddings"
    )

    args = parser.parse_args(argv)

    # A non-positive result count is a usage error, not a search failure.
    if args.k < 1:
        parser.error("--k must be a positive integer")

    # Initialize similarity finder
    try:
        finder = ImageSimilarityFinder(args.embeddings)
    except Exception as e:
        logger.error("Failed to initialize similarity finder: %s", e)
        logger.error("Make sure you're providing the path to the embeddings pickle file, not an image file")
        logger.error("Expected: --embeddings /path/to/image_embeddings.pkl")
        return 1

    # Show info if requested
    if args.info:
        info = finder.get_image_info()
        print("Embeddings Info:")
        print(f" Total images: {info['total_images']}")
        print(f" Embedding dimension: {info['embedding_dimension']}")
        print(f" Sample paths: {info['sample_paths']}")
        print()

    # Find similar images
    try:
        similar_images = finder.find_similar_by_path(
            args.query,
            k=args.k,
            exclude_self=not args.include_self
        )
    except Exception as e:
        logger.error("Failed to find similar images: %s", e)

        # If the query is simply unknown, show a few known paths as a hint.
        # The key is built the same way find_similar_by_path builds it,
        # instead of matching on the wording of the exception message.
        if str(Path(args.query)) not in finder.embeddings:
            logger.error("Available images in embeddings:")
            sample_paths = finder.get_image_info()["sample_paths"]
            for i, path in enumerate(sample_paths, 1):
                logger.error(f" {i}. {path}")
            remaining = len(finder.image_paths) - 5
            if remaining > 0:
                logger.error(f" ... and {remaining} more images")
        return 1

    print(f"Top {len(similar_images)} similar images to '{args.query}':")
    print("-" * 80)

    for i, (image_path, similarity) in enumerate(similar_images, 1):
        print(f"{i:2d}. {Path(image_path).name}")
        print(f" Path: {image_path}")
        print(f" Similarity: {similarity:.4f}")
        print()

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status. SystemExit(None)
    # exits with status 0, so this also works if main() returns nothing.
    raise SystemExit(main())
|