# Img2Vec/find_similar_img.py
#
# 445 lines
# 14 KiB
# Python

import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from embedder import ImageEmbedder
# Configure logging: the root logger is set to INFO and this module logs
# through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default paths
# Pickle file used as the default value of the --image-embeddings CLI option.
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
# Vision model name handed to ImageEmbedder when the caller does not pass one.
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
# Number of results returned per query when top_k / --top-k is not overridden.
DEFAULT_TOP_K = 10
class ImageSimilaritySearcher:
    """Image-to-image similarity search over a pickled embeddings dictionary.

    The embeddings file must contain a ``dict`` whose keys identify images
    (they are returned as the ``'path'`` of each result) and whose values are
    the corresponding embeddings. Query images are embedded with
    ``ImageEmbedder`` and ranked against the stored embeddings by cosine
    similarity.
    """

    def __init__(
        self,
        image_embeddings_path: Path,
        vision_model_name: str = DEFAULT_VISION_MODEL,
        embedding_dim: int = 4096,
        llava_ckpt_path: Optional[str] = None
    ):
        """Load the embeddings database and create the query embedder.

        Args:
            image_embeddings_path: Pickle file holding the embeddings dict.
            vision_model_name: Forwarded to ``ImageEmbedder`` as
                ``vision_model_name``.
            embedding_dim: Expected size of the last embedding axis; also
                forwarded to ``ImageEmbedder`` as ``proj_out_dim``.
            llava_ckpt_path: Forwarded to ``ImageEmbedder`` unchanged.

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: If the file cannot be loaded or fails validation.
        """
        self.image_embeddings_path = Path(image_embeddings_path)
        self.embedding_dim = embedding_dim

        # Load image embeddings
        logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
        self.image_data = self._load_image_embeddings()

        # Initialize image embedder for query images
        logger.info(f"Initializing image embedder: {vision_model_name}")
        self.image_embedder = ImageEmbedder(
            vision_model_name=vision_model_name,
            proj_out_dim=embedding_dim,
            llava_ckpt_path=llava_ckpt_path
        )
        logger.info(f"ImageSimilaritySearcher ready with {len(self.image_data)} images")

    def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
        """Load and sanity-check the embeddings dictionary.

        Only the first entry's last-axis size is compared against
        ``self.embedding_dim``; the remaining entries are not inspected.

        Returns:
            The unpickled dictionary (possibly empty, which logs a warning).

        Raises:
            FileNotFoundError: If the embeddings file does not exist.
            RuntimeError: For any other load or validation failure (the
                original exception is chained as ``__cause__``).
        """
        try:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only point this at embeddings files from a trusted source.
            with open(self.image_embeddings_path, "rb") as f:
                data = pickle.load(f)

            if not isinstance(data, dict):
                raise ValueError("Image embeddings file should contain a dictionary")

            # Verify embedding dimensions
            if data:
                sample_emb = next(iter(data.values()))
                sample_dim = sample_emb.shape[-1]
                if sample_dim != self.embedding_dim:
                    raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_dim}D")
                logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_dim}")
            else:
                logger.warning("Empty embeddings dictionary loaded")
            return data
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Image embeddings file not found: {self.image_embeddings_path}"
            ) from e
        except Exception as e:
            # Validation errors raised above are deliberately folded into
            # RuntimeError too, so callers see a single failure type.
            raise RuntimeError(f"Failed to load image embeddings: {e}") from e

    def _database_matrix(self):
        """Return ``(paths, matrix)`` for the stored embeddings.

        Each stored embedding is flattened to 1-D before stacking, mirroring
        the ``reshape(1, -1)`` applied to the query, so entries saved with a
        leading singleton axis still yield a 2-D ``(N, D)`` matrix.

        Raises:
            ValueError: If the database is empty or embeddings differ in size
                (raised by ``np.stack``).
        """
        paths = list(self.image_data.keys())
        matrix = np.stack([np.asarray(emb).reshape(-1) for emb in self.image_data.values()])
        return paths, matrix

    def _embed_image_file(self, query_image_path: str) -> np.ndarray:
        """Open an image file, convert it to RGB and embed it.

        The file is opened in a ``with`` block so its handle is released as
        soon as the RGB copy has been made.
        """
        with Image.open(query_image_path) as raw_img:
            query_img = raw_img.convert("RGB")
        # NOTE(review): image_to_embedding presumably returns a torch tensor
        # (it is moved with .cpu() and converted with .numpy()) - confirm.
        return self.image_embedder.image_to_embedding(query_img).cpu().numpy()

    def search_by_image_path(
        self,
        query_image_path: str,
        top_k: int = DEFAULT_TOP_K,
        exclude_self: bool = True,
        return_scores: bool = False
    ) -> List[Dict]:
        """Search for images similar to the image stored at a path.

        Args:
            query_image_path: Path to query image
            top_k: Number of top results to return
            exclude_self: Whether to drop the database entry whose key is
                exactly equal to ``query_image_path`` from the results
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys,
            ordered from most to least similar.

        Raises:
            RuntimeError: If the query image cannot be opened or embedded.
        """
        logger.info(f"Searching for images similar to: {query_image_path}")

        # Load and embed query image
        try:
            query_embedding = self._embed_image_file(query_image_path)
        except Exception as e:
            raise RuntimeError(f"Failed to process query image {query_image_path}: {e}") from e

        return self._search_by_embedding(
            query_embedding,
            query_image_path if exclude_self else None,
            top_k,
            return_scores
        )

    def search_by_image(
        self,
        query_image: Image.Image,
        top_k: int = DEFAULT_TOP_K,
        return_scores: bool = False
    ) -> List[Dict]:
        """Search for images similar to an in-memory PIL image.

        Args:
            query_image: PIL Image object
            top_k: Number of top results to return
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys,
            ordered from most to least similar.
        """
        logger.info("Searching for similar images using provided image")

        # Embed query image
        query_embedding = self.image_embedder.image_to_embedding(query_image).cpu().numpy()
        return self._search_by_embedding(query_embedding, None, top_k, return_scores)

    def _search_by_embedding(
        self,
        query_embedding: np.ndarray,
        exclude_path: Optional[str],
        top_k: int,
        return_scores: bool
    ) -> List[Dict]:
        """Rank the stored embeddings against a query embedding.

        Args:
            query_embedding: Query vector; flattened to a single row.
            exclude_path: Database key to leave out of the ranking, or a
                falsy value to keep every entry.
            top_k: Maximum number of results; non-positive values give ``[]``.
            return_scores: Whether each result carries a ``'score'`` float.

        Returns:
            Up to ``top_k`` result dictionaries, best match first. An empty
            database yields an empty list.
        """
        # Nothing to rank: avoid np.stack failing on an empty sequence and
        # avoid the surprising "all but the last k" slice for negative top_k.
        if top_k <= 0 or not self.image_data:
            return []

        # Prepare database embeddings
        image_paths, image_embeddings = self._database_matrix()  # (N, D)

        # Compute cosine similarities
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            image_embeddings
        ).flatten()

        # Handle exclusion
        if exclude_path:
            try:
                exclude_idx = image_paths.index(exclude_path)
            except ValueError:
                pass  # Query path not in database
            else:
                similarities = np.delete(similarities, exclude_idx)
                del image_paths[exclude_idx]

        # Get top-k indices; a stable sort keeps tied scores in database order
        top_k_indices = np.argsort(-similarities, kind="stable")[:top_k]

        # Prepare results
        results = []
        for idx in top_k_indices:
            result = {'path': image_paths[idx]}
            if return_scores:
                result['score'] = float(similarities[idx])
            results.append(result)
        return results

    def batch_search(
        self,
        query_image_paths: List[str],
        top_k: int = DEFAULT_TOP_K,
        exclude_self: bool = True
    ) -> Dict[str, List[Dict]]:
        """Search for several query images one after another.

        A failure on one query is logged and recorded as an empty result
        list, so the remaining queries still run.

        Args:
            query_image_paths: List of paths to query images
            top_k: Number of results per query
            exclude_self: Whether to exclude query images from results

        Returns:
            Dictionary mapping query path to list of results (with scores)
        """
        logger.info(f"Batch searching {len(query_image_paths)} images")
        results = {}
        for query_path in query_image_paths:
            try:
                results[query_path] = self.search_by_image_path(
                    query_path,
                    top_k=top_k,
                    exclude_self=exclude_self,
                    return_scores=True
                )
            except Exception as e:
                logger.error(f"Failed to search for {query_path}: {e}")
                results[query_path] = []
        return results

    def get_similarity_stats(self, query_image_path: str) -> Dict:
        """Summarise the similarity of a query image to every stored image.

        Args:
            query_image_path: Path to query image

        Returns:
            Dictionary with the mean, max, min, standard deviation and median
            of the cosine similarities, as floats.

        Raises:
            RuntimeError: If the query image cannot be opened or embedded.
            ValueError: If the embeddings database is empty.
        """
        try:
            query_embedding = self._embed_image_file(query_image_path)
        except Exception as e:
            raise RuntimeError(f"Failed to process query image: {e}") from e

        if not self.image_data:
            raise ValueError("Cannot compute similarity statistics: no image embeddings loaded")

        _, image_embeddings = self._database_matrix()
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            image_embeddings
        ).flatten()

        return {
            'mean_similarity': float(np.mean(similarities)),
            'max_similarity': float(np.max(similarities)),
            'min_similarity': float(np.min(similarities)),
            'std_similarity': float(np.std(similarities)),
            'median_similarity': float(np.median(similarities))
        }

    def get_memory_usage(self):
        """Return whatever ``ImageEmbedder.get_memory_usage()`` reports."""
        return self.image_embedder.get_memory_usage()
def print_results(results: List[Dict], query: str, show_scores: bool = True):
    """Write a numbered listing of search hits to stdout.

    Args:
        results: Result dictionaries; each must have a 'path' key and may
            have a 'score' key.
        query: Query image path; its file name is shown in the heading and
            the full value on the line below.
        show_scores: When true, hits that carry a 'score' are printed with
            the path padded to 80 columns followed by the score.
    """
    separator = "-" * 100

    print(f"\nTop {len(results)} similar images to: {Path(query).name}")
    print(f"Query: {query}")
    print(separator)

    for rank, hit in enumerate(results, start=1):
        location = hit['path']
        # Hits without a score (or with scores disabled) get the short form.
        if not (show_scores and 'score' in hit):
            print(f"{rank:2d}. {location}")
            continue
        similarity = hit['score']
        print(f"{rank:2d}. {location:<80} score={similarity:>7.4f}")

    print(separator)
def main():
    """Command-line entry point for image-to-image similarity search.

    Exactly one query mode is used, checked in this order: ``--query-image``
    (single image), ``--query-images-file`` (one path per line) or
    ``--query-dir`` (every image file directly inside a directory). Results
    are printed to stdout and, when ``--output`` is given, written as JSON
    mapping each query path to its result list in every mode.

    Failures are logged and end the run early instead of raising.
    """
    parser = argparse.ArgumentParser(description="Image-to-Image Similarity Search")

    # Input/Output paths
    parser.add_argument(
        "--image-embeddings",
        type=str,
        default=str(DEFAULT_IMG_EMB_PATH),
        help="Path to image embeddings pickle file"
    )

    # Query options
    parser.add_argument(
        "--query-image",
        type=str,
        help="Path to query image"
    )
    parser.add_argument(
        "--query-images-file",
        type=str,
        help="File with one image path per line"
    )
    parser.add_argument(
        "--query-dir",
        type=str,
        help="Directory containing query images"
    )

    # Model parameters
    parser.add_argument(
        "--vision-model",
        type=str,
        default=DEFAULT_VISION_MODEL,
        help="CLIP vision model to use"
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=4096,
        help="Embedding dimension"
    )
    parser.add_argument(
        "--llava-ckpt",
        type=str,
        help="Path to LLaVA checkpoint for projection weights"
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=DEFAULT_TOP_K,
        help="Number of top results to return"
    )
    parser.add_argument(
        "--include-self",
        action="store_true",
        help="Include query image in results"
    )

    # Output options
    parser.add_argument(
        "--output",
        type=str,
        help="Save results to JSON file"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show similarity statistics"
    )
    args = parser.parse_args()

    # Initialize searcher
    try:
        searcher = ImageSimilaritySearcher(
            image_embeddings_path=args.image_embeddings,
            vision_model_name=args.vision_model,
            embedding_dim=args.embedding_dim,
            llava_ckpt_path=args.llava_ckpt
        )
    except Exception as e:
        logger.error(f"Failed to initialize searcher: {e}")
        return

    # Show memory usage
    mem = searcher.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    exclude_self = not args.include_self

    # Handle different query modes
    if args.query_image:
        # Single query image: log failures like the batch modes do instead of
        # letting a traceback escape.
        try:
            results = searcher.search_by_image_path(
                args.query_image,
                top_k=args.top_k,
                exclude_self=exclude_self,
                return_scores=True
            )
        except Exception as e:
            logger.error(f"Failed to search for {args.query_image}: {e}")
            return
        print_results(results, args.query_image)

        if args.stats:
            try:
                stats = searcher.get_similarity_stats(args.query_image)
            except Exception as e:
                logger.error(f"Failed to compute similarity statistics: {e}")
            else:
                print("\nSimilarity Statistics:")
                for key, value in stats.items():
                    print(f"  {key}: {value:.4f}")

        # Same shape as the batch modes so --output works here as well.
        batch_results = {args.query_image: results}

    elif args.query_images_file:
        # Batch queries from file
        try:
            with open(args.query_images_file, 'r', encoding='utf-8') as f:
                query_paths = [line.strip() for line in f if line.strip()]
        except (OSError, UnicodeError) as e:
            logger.error(f"Failed to process queries file: {e}")
            return
        batch_results = searcher.batch_search(
            query_paths,
            top_k=args.top_k,
            exclude_self=exclude_self
        )
        for query_path, results in batch_results.items():
            if results:  # Only show if we got results
                print_results(results, query_path)

    elif args.query_dir:
        # Query all images in directory
        query_dir = Path(args.query_dir)
        if not query_dir.is_dir():
            logger.error(f"Query directory does not exist or is not a directory: {query_dir}")
            return

        # Find images in directory. Matching on the lower-cased suffix finds
        # every case variant once, without the duplicates that globbing both
        # "*.jpg" and "*.JPG" produces on case-insensitive filesystems.
        image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff', '.tif'}
        query_paths = [
            str(p)
            for p in sorted(query_dir.iterdir())
            if p.is_file() and p.suffix.lower() in image_exts
        ]
        if not query_paths:
            logger.error(f"No images found in {query_dir}")
            return
        logger.info(f"Found {len(query_paths)} images in {query_dir}")

        batch_results = searcher.batch_search(
            query_paths,
            top_k=args.top_k,
            exclude_self=exclude_self
        )
        for query_path, results in batch_results.items():
            if results:  # Only show if we got results
                print_results(results, query_path)

    else:
        print("Please specify a query image, query images file, or query directory")
        print("Use --help for more information")
        return

    # Save results if requested (every branch that reaches here set batch_results)
    if args.output:
        import json
        try:
            with open(args.output, 'w', encoding='utf-8') as f:
                json.dump(batch_results, f, indent=2, ensure_ascii=False)
            logger.info(f"Results saved to {args.output}")
        except Exception as e:
            logger.error(f"Failed to save results: {e}")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()