text2img/examples/similarity_finder.py
2025-06-18 16:07:02 +08:00

262 lines
8.6 KiB
Python

import argparse
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from PIL import Image
# Root logging at INFO so the finder's progress/errors are visible when run as a script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default embedding locations: an "embeddings" directory two levels above this file.
EMBED_DIR = Path(__file__).parent.parent.joinpath("embeddings")
IMAGE_EMB_PATH = EMBED_DIR.joinpath("image_embeddings.pkl")
class ImageSimilarityFinder:
    """
    Find similar images using precomputed embeddings from image_embeddings.pkl.

    Images are ranked by cosine similarity. Both the query vector and every
    stored vector are L2-normalized before the dot product. Scores are
    therefore true cosine similarities even if the pickle does not hold
    unit-length vectors.
    """

    def __init__(self, embeddings_path: Union[str, Path] = IMAGE_EMB_PATH):
        """
        Initialize the similarity finder with precomputed embeddings.

        Args:
            embeddings_path: Path to the pickle file containing a mapping of
                image path (str) -> embedding vector.

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: If the file exists but cannot be unpickled.
        """
        self.embeddings_path = Path(embeddings_path)
        self.embeddings = self._load_embeddings()
        self.image_paths = list(self.embeddings.keys())
        # Raw stored vectors, one row per entry of self.image_paths (same order).
        self.embedding_matrix = np.array(list(self.embeddings.values()))
        # Row-normalized copy used for ranking; computed once here rather than per query.
        self._unit_matrix = self._normalize_rows(self.embedding_matrix)
        logger.info(f"Loaded {len(self.embeddings)} image embeddings")

    @staticmethod
    def _normalize_rows(matrix: np.ndarray) -> np.ndarray:
        """
        Return a copy of a 2-D ``matrix`` with every row scaled to unit L2 norm.

        All-zero rows stay zero instead of becoming NaN. Input that is not
        2-D is returned unchanged; an empty embeddings dict produces such
        input, a shape-(0,) array.
        """
        matrix = np.asarray(matrix)
        if matrix.ndim != 2:
            return matrix
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        # Dividing by 1 instead of 0 keeps degenerate rows at similarity 0 with everything.
        safe_norms = np.where(norms == 0, 1.0, norms)
        return matrix / safe_norms

    def _load_embeddings(self) -> Dict[str, np.ndarray]:
        """
        Load the image-path -> embedding mapping from ``self.embeddings_path``.

        SECURITY: ``pickle.load`` can execute arbitrary code while loading.
        Only load embedding files produced by a trusted process.

        Raises:
            FileNotFoundError: If the file does not exist.
            RuntimeError: If unpickling fails for any other reason.
        """
        try:
            with open(self.embeddings_path, "rb") as fp:
                return pickle.load(fp)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Embeddings file not found: {self.embeddings_path}") from e
        except Exception as e:
            # Keep the original exception as __cause__ so the real failure stays debuggable.
            raise RuntimeError(f"Failed to load embeddings: {e}") from e

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Return the cosine similarity between two 1-D vectors.

        Both vectors are normalized explicitly, so the result is correct even
        if the inputs are not unit-length. Returns 0.0 if either vector has
        zero norm, where cosine similarity is undefined.
        """
        denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denom == 0:
            return 0.0
        return float(np.dot(vec1, vec2) / denom)

    def find_similar_by_path(
        self,
        query_image_path: Union[str, Path],
        k: int = 5,
        exclude_self: bool = True
    ) -> List[Tuple[str, float]]:
        """
        Find the k images most similar to a query image that is already embedded.

        The query is looked up by its exact string first, then by its
        ``pathlib``-normalized form. Keys stored as e.g. "./img/a.jpg"
        therefore still match.

        Args:
            query_image_path: Path of the query image, as used as a key in the embeddings.
            k: Number of similar images to return; k <= 0 returns an empty list.
            exclude_self: Whether to drop the query image itself from the results.

        Returns:
            List of (image_path, similarity_score) tuples, highest similarity first.

        Raises:
            ValueError: If the query image is not present in the embeddings.
        """
        raw_key = str(query_image_path)
        normalized_key = str(Path(query_image_path))
        if raw_key in self.embeddings:
            query_path = raw_key
        elif normalized_key in self.embeddings:
            query_path = normalized_key
        else:
            raise ValueError(f"Query image not found in embeddings: {normalized_key}")
        query_embedding = self.embeddings[query_path]
        return self._find_similar_by_embedding(
            query_embedding, k, exclude_path=query_path if exclude_self else None
        )

    def find_similar_by_embedding(
        self,
        query_embedding: np.ndarray,
        k: int = 5
    ) -> List[Tuple[str, float]]:
        """
        Find the k images most similar to an arbitrary query embedding.

        Args:
            query_embedding: Query vector; it need not be normalized.
            k: Number of similar images to return; k <= 0 returns an empty list.

        Returns:
            List of (image_path, similarity_score) tuples, highest similarity first.

        Raises:
            ValueError: If the query has zero norm or its length does not match
                the stored embeddings.
        """
        return self._find_similar_by_embedding(query_embedding, k)

    def _find_similar_by_embedding(
        self,
        query_embedding: np.ndarray,
        k: int,
        exclude_path: Optional[str] = None
    ) -> List[Tuple[str, float]]:
        """
        Rank all stored images against a query embedding by cosine similarity.

        Args:
            query_embedding: Query vector (array-like); normalized here.
            k: Maximum number of results; k <= 0 yields an empty list.
            exclude_path: Image path to leave out of the results, typically the
                query image itself. ``None`` excludes nothing.

        Returns:
            Up to k (image_path, similarity_score) tuples, highest similarity first.

        Raises:
            ValueError: If the query embedding has zero norm.
        """
        # Without this guard the loop below appends before checking, so k <= 0 returned one hit.
        if k <= 0 or not self.image_paths:
            return []

        query = np.asarray(query_embedding)
        query_norm = np.linalg.norm(query)
        if query_norm == 0:
            raise ValueError("Query embedding has zero norm; cosine similarity is undefined")
        query = query / query_norm

        similarities = np.dot(self._unit_matrix, query)
        # Indices of all stored images, most similar first.
        sorted_indices = np.argsort(similarities)[::-1]

        results: List[Tuple[str, float]] = []
        for idx in sorted_indices:
            image_path = self.image_paths[idx]
            if exclude_path is not None and image_path == exclude_path:
                continue
            results.append((image_path, float(similarities[idx])))
            if len(results) >= k:
                break
        return results

    def find_similar_images_batch(
        self,
        query_paths: List[Union[str, Path]],
        k: int = 5
    ) -> Dict[str, List[Tuple[str, float]]]:
        """
        Find similar images for several query images at once.

        A query that raises ValueError is logged and mapped to an empty list.
        This covers a query missing from the embeddings and a zero-norm
        stored vector.

        Args:
            query_paths: Paths of the query images.
            k: Number of similar images to return per query.

        Returns:
            Dictionary mapping str(query_path) -> list of (similar_path, similarity_score).
        """
        results = {}
        for query_path in query_paths:
            try:
                results[str(query_path)] = self.find_similar_by_path(query_path, k)
            except ValueError as e:
                logger.warning(f"Skipping {query_path}: {e}")
                results[str(query_path)] = []
        return results

    def get_image_info(self, num_samples: int = 5) -> Dict[str, Any]:
        """
        Summarize the loaded embeddings.

        Args:
            num_samples: How many image paths to include under "sample_paths".

        Returns:
            Dict with "total_images" (int), "embedding_dimension" (int, 0 when
            there are no 2-D embeddings) and "sample_paths" (list of str).
        """
        matrix = self.embedding_matrix
        # shape[1] only exists for a non-empty 2-D matrix; anything else reports 0.
        has_dim = matrix.ndim >= 2 and len(matrix) > 0
        return {
            "total_images": len(self.embeddings),
            "embedding_dimension": matrix.shape[1] if has_dim else 0,
            "sample_paths": self.image_paths[:max(num_samples, 0)],
        }
def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point: print the images most similar to ``--query``.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]`` (handy for tests).
            ``None`` keeps argparse's default behavior.

    Returns:
        Process exit status: 0 on success, 1 if the embeddings could not be
        loaded or the similarity search failed. Invalid ``--k`` values make
        argparse exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Find similar images using precomputed embeddings"
    )
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Path to the query image"
    )
    parser.add_argument(
        "--embeddings",
        type=str,
        default=str(IMAGE_EMB_PATH),
        help="Path to the embeddings pickle file"
    )
    parser.add_argument(
        "--k",
        type=int,
        default=5,
        help="Number of similar images to return (default: 5)"
    )
    parser.add_argument(
        "--include-self",
        action="store_true",
        help="Include the query image in results"
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Show information about loaded embeddings"
    )
    args = parser.parse_args(argv)
    if args.k < 1:
        parser.error("--k must be a positive integer")

    # Initialize similarity finder
    try:
        finder = ImageSimilarityFinder(args.embeddings)
    except Exception as e:  # top-level boundary: report and exit non-zero
        logger.error(f"Failed to initialize similarity finder: {e}")
        logger.error("Make sure you're providing the path to the embeddings pickle file, not an image file")
        logger.error("Expected: --embeddings /path/to/image_embeddings.pkl")
        return 1

    # Show info if requested
    if args.info:
        info = finder.get_image_info()
        print("Embeddings Info:")
        print(f" Total images: {info['total_images']}")
        print(f" Embedding dimension: {info['embedding_dimension']}")
        print(f" Sample paths: {info['sample_paths']}")
        print()

    # Find similar images
    try:
        similar_images = finder.find_similar_by_path(
            args.query,
            k=args.k,
            exclude_self=not args.include_self
        )
    except Exception as e:  # top-level boundary: report and exit non-zero
        logger.error(f"Failed to find similar images: {e}")
        # A missing query is the common mistake; show some valid keys to help.
        if "not found in embeddings" in str(e):
            logger.error("Available images in embeddings:")
            sample_paths = finder.get_image_info()["sample_paths"]
            for i, path in enumerate(sample_paths, 1):
                logger.error(f" {i}. {path}")
            remaining = len(finder.image_paths) - len(sample_paths)
            if remaining > 0:
                logger.error(f" ... and {remaining} more images")
        return 1

    print(f"Top {len(similar_images)} similar images to '{args.query}':")
    print("-" * 80)
    for i, (image_path, similarity) in enumerate(similar_images, 1):
        print(f"{i:2d}. {Path(image_path).name}")
        print(f" Path: {image_path}")
        print(f" Similarity: {similarity:.4f}")
        print()
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell scripts can detect failures.
    raise SystemExit(main())