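"""
Image-to-image similarity search over precomputed embeddings.

Loads a pickle mapping image paths to embedding vectors, embeds a query image
with ImageEmbedder, and ranks database images by cosine similarity.

Example invocation (illustrative script name and paths; the flags are the ones
defined in main() below):

    python search_images.py \
        --image-embeddings embeddings/image_embeddings.pkl \
        --query-image path/to/query.jpg \
        --top-k 5 --stats
"""
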
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from embedder import ImageEmbedder

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
DEFAULT_TOP_K = 10


class ImageSimilaritySearcher:
    """
    Image-to-image similarity search using a unified embeddings dictionary
    """

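    # Typical in-process usage (illustrative paths; the result format matches
    # the search_by_image_path docstring below):
    #   searcher = ImageSimilaritySearcher(Path("embeddings/image_embeddings.pkl"))
    #   hits = searcher.search_by_image_path("query.jpg", top_k=5, return_scores=True)
    #   # -> [{'path': '...', 'score': ...}, ...]
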
    def __init__(
        self,
        image_embeddings_path: Path,
        vision_model_name: str = DEFAULT_VISION_MODEL,
        embedding_dim: int = 4096,
        llava_ckpt_path: Optional[str] = None
    ):
        self.image_embeddings_path = Path(image_embeddings_path)
        self.embedding_dim = embedding_dim

        # Load image embeddings
        logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
        self.image_data = self._load_image_embeddings()

        # Initialize image embedder for query images
        logger.info(f"Initializing image embedder: {vision_model_name}")
        self.image_embedder = ImageEmbedder(
            vision_model_name=vision_model_name,
            proj_out_dim=embedding_dim,
            llava_ckpt_path=llava_ckpt_path
        )

        logger.info(f"ImageSimilaritySearcher ready with {len(self.image_data)} images")

    def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
        """Load image embeddings from unified pickle file"""
        try:
            with open(self.image_embeddings_path, "rb") as f:
                data = pickle.load(f)

            if not isinstance(data, dict):
                raise ValueError("Image embeddings file should contain a dictionary")

            # Verify embedding dimensions
            if data:
                sample_key = next(iter(data))
                sample_emb = data[sample_key]
                if sample_emb.shape[-1] != self.embedding_dim:
                    raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D")

                logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
            else:
                logger.warning("Empty embeddings dictionary loaded")

            return data

        except FileNotFoundError:
            raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load image embeddings: {e}")

    def search_by_image_path(
        self,
        query_image_path: str,
        top_k: int = DEFAULT_TOP_K,
        exclude_self: bool = True,
        return_scores: bool = False
    ) -> List[Dict]:
        """
        Search for similar images given a query image path

        Args:
            query_image_path: Path to query image
            top_k: Number of top results to return
            exclude_self: Whether to exclude the query image from results
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys
        """
        logger.info(f"Searching for images similar to: {query_image_path}")

        # Load and embed query image
        try:
            query_img = Image.open(query_image_path).convert("RGB")
            query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to process query image {query_image_path}: {e}")

        return self._search_by_embedding(
            query_embedding,
            query_image_path if exclude_self else None,
            top_k,
            return_scores
        )

    def search_by_image(
        self,
        query_image: Image.Image,
        top_k: int = DEFAULT_TOP_K,
        return_scores: bool = False
    ) -> List[Dict]:
        """
        Search for similar images given a PIL Image

        Args:
            query_image: PIL Image object
            top_k: Number of top results to return
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys
        """
        logger.info("Searching for similar images using provided image")

        # Embed query image
        query_embedding = self.image_embedder.image_to_embedding(query_image).cpu().numpy()

        return self._search_by_embedding(query_embedding, None, top_k, return_scores)

    def _search_by_embedding(
        self,
        query_embedding: np.ndarray,
        exclude_path: Optional[str],
        top_k: int,
        return_scores: bool
    ) -> List[Dict]:
        """Internal method to search by embedding vector"""

        if not self.image_data:
            logger.warning("No image embeddings loaded; returning empty results")
            return []

        # Prepare database embeddings
        image_paths = list(self.image_data.keys())
        image_embeddings = np.stack(list(self.image_data.values()))  # (N, D)

        # Compute cosine similarity of the query against every database image
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            image_embeddings
        ).flatten()

        # Optionally drop the query image itself from the candidate set
        if exclude_path:
            try:
                exclude_idx = image_paths.index(exclude_path)
                similarities = np.delete(similarities, exclude_idx)
                image_paths = [p for i, p in enumerate(image_paths) if i != exclude_idx]
            except ValueError:
                pass  # Query path not in database

        # Get top-k indices
        top_k_indices = np.argsort(-similarities)[:top_k]

        # Prepare results
        results = []
        for idx in top_k_indices:
            result = {'path': image_paths[idx]}
            if return_scores:
                result['score'] = float(similarities[idx])
            results.append(result)

        return results

    def batch_search(
        self,
        query_image_paths: List[str],
        top_k: int = DEFAULT_TOP_K,
        exclude_self: bool = True
    ) -> Dict[str, List[Dict]]:
        """
        Search for multiple query images at once

        Args:
            query_image_paths: List of paths to query images
            top_k: Number of results per query
            exclude_self: Whether to exclude query images from results

        Returns:
            Dictionary mapping query path to list of results
        """
        logger.info(f"Batch searching {len(query_image_paths)} images")

        results = {}
        for query_path in query_image_paths:
            try:
                results[query_path] = self.search_by_image_path(
                    query_path,
                    top_k=top_k,
                    exclude_self=exclude_self,
                    return_scores=True
                )
            except Exception as e:
                logger.error(f"Failed to search for {query_path}: {e}")
                results[query_path] = []

        return results

    def get_similarity_stats(self, query_image_path: str) -> Dict:
        """
        Get statistics about similarity scores for a query image

        Args:
            query_image_path: Path to query image

        Returns:
            Dictionary with similarity statistics
        """
        try:
            query_img = Image.open(query_image_path).convert("RGB")
            query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to process query image: {e}")

        image_embeddings = np.stack(list(self.image_data.values()))
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            image_embeddings
        ).flatten()

        return {
            'mean_similarity': float(np.mean(similarities)),
            'max_similarity': float(np.max(similarities)),
            'min_similarity': float(np.min(similarities)),
            'std_similarity': float(np.std(similarities)),
            'median_similarity': float(np.median(similarities))
        }

    def get_memory_usage(self):
        """Get memory usage information"""
        return self.image_embedder.get_memory_usage()


def print_results(results: List[Dict], query: str, show_scores: bool = True):
    """Pretty print search results"""
    print(f"\nTop {len(results)} similar images to: {Path(query).name}")
    print(f"Query: {query}")
    print("-" * 100)

    for i, result in enumerate(results, 1):
        path = result['path']
        if show_scores and 'score' in result:
            score = result['score']
            print(f"{i:2d}. {path:<80} score={score:>7.4f}")
        else:
            print(f"{i:2d}. {path}")

    print("-" * 100)


def main():
    parser = argparse.ArgumentParser(description="Image-to-Image Similarity Search")

    # Input/Output paths
    parser.add_argument(
        "--image-embeddings",
        type=str,
        default=str(DEFAULT_IMG_EMB_PATH),
        help="Path to image embeddings pickle file"
    )

    # Query options
    parser.add_argument(
        "--query-image",
        type=str,
        help="Path to query image"
    )
    parser.add_argument(
        "--query-images-file",
        type=str,
        help="File with one image path per line"
    )
    parser.add_argument(
        "--query-dir",
        type=str,
        help="Directory containing query images"
    )

    # Model parameters
    parser.add_argument(
        "--vision-model",
        type=str,
        default=DEFAULT_VISION_MODEL,
        help="CLIP vision model to use"
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=4096,
        help="Embedding dimension"
    )
    parser.add_argument(
        "--llava-ckpt",
        type=str,
        help="Path to LLaVA checkpoint for projection weights"
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=DEFAULT_TOP_K,
        help="Number of top results to return"
    )
    parser.add_argument(
        "--include-self",
        action="store_true",
        help="Include query image in results"
    )

    # Output options
    parser.add_argument(
        "--output",
        type=str,
        help="Save results to JSON file"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show similarity statistics"
    )

    args = parser.parse_args()

    # Initialize searcher
    try:
        searcher = ImageSimilaritySearcher(
            image_embeddings_path=args.image_embeddings,
            vision_model_name=args.vision_model,
            embedding_dim=args.embedding_dim,
            llava_ckpt_path=args.llava_ckpt
        )
    except Exception as e:
        logger.error(f"Failed to initialize searcher: {e}")
        return

    # Show memory usage
    mem = searcher.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Handle different query modes
    if args.query_image:
        # Single query image
        results = searcher.search_by_image_path(
            args.query_image,
            top_k=args.top_k,
            exclude_self=not args.include_self,
            return_scores=True
        )
        print_results(results, args.query_image)
        # Wrap in the batch format so --output also saves single-query results
        batch_results = {args.query_image: results}

        if args.stats:
            stats = searcher.get_similarity_stats(args.query_image)
            print("\nSimilarity Statistics:")
            for key, value in stats.items():
                print(f"  {key}: {value:.4f}")

    elif args.query_images_file:
        # Batch queries from file
        try:
            with open(args.query_images_file, 'r', encoding='utf-8') as f:
                query_paths = [line.strip() for line in f if line.strip()]

            batch_results = searcher.batch_search(
                query_paths,
                top_k=args.top_k,
                exclude_self=not args.include_self
            )

            for query_path, results in batch_results.items():
                if results:  # Only show if we got results
                    print_results(results, query_path)

        except Exception as e:
            logger.error(f"Failed to process queries file: {e}")
            return

    elif args.query_dir:
        # Query all images in directory
        query_dir = Path(args.query_dir)
        if not query_dir.exists():
            logger.error(f"Query directory does not exist: {query_dir}")
            return

        # Find images in directory
        image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff', '.tif'}
        query_paths = []
        for ext in image_exts:
            query_paths.extend(query_dir.glob(f"*{ext}"))
            query_paths.extend(query_dir.glob(f"*{ext.upper()}"))

        # De-duplicate (case-insensitive filesystems match both glob patterns) and sort
        query_paths = sorted({str(p) for p in query_paths})

        if not query_paths:
            logger.error(f"No images found in {query_dir}")
            return

        logger.info(f"Found {len(query_paths)} images in {query_dir}")

        batch_results = searcher.batch_search(
            query_paths,
            top_k=args.top_k,
            exclude_self=not args.include_self
        )

        for query_path, results in batch_results.items():
            if results:  # Only show if we got results
                print_results(results, query_path)

    else:
        print("Please specify a query image, query images file, or query directory")
        print("Use --help for more information")
        return

    # Save results if requested
    if args.output and 'batch_results' in locals():
        import json

        try:
            with open(args.output, 'w', encoding='utf-8') as f:
                json.dump(batch_results, f, indent=2, ensure_ascii=False)
            logger.info(f"Results saved to {args.output}")
        except Exception as e:
            logger.error(f"Failed to save results: {e}")


if __name__ == "__main__":
    main()