Img2Vec/text_search_starter.py


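"""
Text-to-image search starter.

Loads precomputed image embeddings from a pickle file, embeds a text query with
a CLIP text encoder (CLIPTextEmbedder), and ranks images by cosine similarity.

Example CLI usage (illustrative paths; the flags mirror the argparse options in main()):

    python text_search_starter.py --query "a cat sleeping on a bed" --top-k 5
    python text_search_starter.py --queries-file queries.txt --output results.json
    python text_search_starter.py --interactive --stats
"""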
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict, List
import numpy as np
from text_embedder import CLIPTextEmbedder
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default paths
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
DEFAULT_TEXT_MODEL = "openai/clip-vit-base-patch32"
DEFAULT_TOP_K = 10
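
# Note: the embeddings pickle is assumed (based on how it is consumed below) to map
# image path (str) -> embedding vector (np.ndarray of shape (embedding_dim,)),
# with vectors already L2-normalized so that a dot product equals cosine similarity.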

class TextToImageSearcher:
    """
    Text-to-image search using CLIP embeddings.
    """

    def __init__(
        self,
        image_embeddings_path: Path,
        text_model_name: str = DEFAULT_TEXT_MODEL,
        embedding_dim: int = 4096
    ):
        """
        Initialize the Text-to-Image searcher.

        Args:
            image_embeddings_path: Path to the pickle file containing image embeddings
            text_model_name: Name of the CLIP text model to use
            embedding_dim: Dimension of the embeddings (should match image embeddings)
        """
        self.image_embeddings_path = Path(image_embeddings_path)
        self.embedding_dim = embedding_dim

        # Load image embeddings
        logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
        self.image_data = self._load_image_embeddings()

        # Initialize text embedder
        logger.info(f"Initializing text embedder: {text_model_name}")
        self.text_embedder = CLIPTextEmbedder(
            model_name=text_model_name,
            output_dim=embedding_dim
        )

        logger.info(f"TextToImageSearcher ready with {len(self.image_data)} images")
    def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
        """Load image embeddings from pickle file."""
        try:
            with open(self.image_embeddings_path, "rb") as f:
                data = pickle.load(f)

            if not isinstance(data, dict):
                raise ValueError("Image embeddings file should contain a dictionary")
            if not data:
                raise ValueError("Image embeddings file is empty")

            # Verify embedding dimensions against a sample entry
            sample_key = next(iter(data))
            sample_emb = data[sample_key]
            if sample_emb.shape[-1] != self.embedding_dim:
                raise ValueError(
                    f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D"
                )

            logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
            return data
        except FileNotFoundError:
            raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load image embeddings: {e}") from e
    def search(
        self,
        query_text: str,
        top_k: int = DEFAULT_TOP_K,
        return_scores: bool = False
    ) -> List[Dict]:
        """
        Search for images matching the query text.

        Args:
            query_text: Text description to search for
            top_k: Number of top results to return
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys
        """
        logger.info(f"Searching for: '{query_text}'")

        # Embed query text
        query_embedding = self.text_embedder.embed_single_text(query_text)

        # Prepare image data for comparison
        image_paths = list(self.image_data.keys())
        image_embeddings = np.stack(list(self.image_data.values()))  # (N, D)

        # Compute cosine similarities (both embeddings are normalized)
        similarities = np.dot(image_embeddings, query_embedding)

        # Get top-k indices
        top_k_indices = np.argsort(-similarities)[:top_k]

        # Prepare results
        results = []
        for idx in top_k_indices:
            result = {'path': image_paths[idx]}
            if return_scores:
                result['score'] = float(similarities[idx])
            results.append(result)

        return results
    def batch_search(
        self,
        query_texts: List[str],
        top_k: int = DEFAULT_TOP_K
    ) -> Dict[str, List[Dict]]:
        """
        Search for multiple queries at once.

        Args:
            query_texts: List of text descriptions
            top_k: Number of results per query

        Returns:
            Dictionary mapping query text to list of results
        """
        logger.info(f"Batch searching {len(query_texts)} queries")
        results = {}
        for query in query_texts:
            results[query] = self.search(query, top_k=top_k, return_scores=True)
        return results
    def get_search_stats(self, query_text: str) -> Dict:
        """
        Get statistics about search results.

        Args:
            query_text: Query to analyze

        Returns:
            Dictionary with similarity statistics
        """
        query_embedding = self.text_embedder.embed_single_text(query_text)
        image_embeddings = np.stack(list(self.image_data.values()))
        similarities = np.dot(image_embeddings, query_embedding)
        return {
            'mean_similarity': float(np.mean(similarities)),
            'max_similarity': float(np.max(similarities)),
            'min_similarity': float(np.min(similarities)),
            'std_similarity': float(np.std(similarities)),
            'median_similarity': float(np.median(similarities))
        }
    def get_memory_usage(self):
        """Get memory usage information."""
        return self.text_embedder.get_memory_usage()

    def clear_cache(self):
        """Clear text embedding cache."""
        self.text_embedder.clear_cache()
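
# Example programmatic usage (a minimal sketch; assumes an embeddings pickle exists
# at the default path and that CLIPTextEmbedder is importable):
#
#   searcher = TextToImageSearcher(
#       image_embeddings_path=DEFAULT_IMG_EMB_PATH,
#       text_model_name=DEFAULT_TEXT_MODEL,
#       embedding_dim=4096,
#   )
#   hits = searcher.search("a dog playing in the snow", top_k=5, return_scores=True)
#   for hit in hits:
#       print(hit["path"], hit["score"])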

def print_results(results: List[Dict], query: str, show_scores: bool = True):
    """Pretty print search results."""
    print(f"\nTop {len(results)} results for: '{query}'")
    print("-" * 80)
    for i, result in enumerate(results, 1):
        path = result['path']
        if show_scores and 'score' in result:
            score = result['score']
            print(f"{i:2d}. {path:<60} score={score:>7.4f}")
        else:
            print(f"{i:2d}. {path}")
    print("-" * 80)

def main():
    parser = argparse.ArgumentParser(description="Text-to-Image Search using CLIP embeddings")

    # Input/Output paths
    parser.add_argument(
        "--image-embeddings",
        type=str,
        default=str(DEFAULT_IMG_EMB_PATH),
        help="Path to image embeddings pickle file"
    )

    # Query options
    parser.add_argument(
        "--query",
        type=str,
        help="Single text query to search for"
    )
    parser.add_argument(
        "--queries-file",
        type=str,
        help="File with one query per line"
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Interactive mode for multiple queries"
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=DEFAULT_TOP_K,
        help="Number of top results to return"
    )
    parser.add_argument(
        "--text-model",
        type=str,
        default=DEFAULT_TEXT_MODEL,
        help="CLIP text model to use"
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=4096,
        help="Embedding dimension"
    )

    # Output options
    parser.add_argument(
        "--output",
        type=str,
        help="Save results to JSON file"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show similarity statistics"
    )

    args = parser.parse_args()
    # Initialize searcher
    try:
        searcher = TextToImageSearcher(
            image_embeddings_path=args.image_embeddings,
            text_model_name=args.text_model,
            embedding_dim=args.embedding_dim
        )
    except Exception as e:
        logger.error(f"Failed to initialize searcher: {e}")
        return

    # Show memory usage
    mem = searcher.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
    # Handle different query modes
    if args.query:
        # Single query
        results = searcher.search(args.query, top_k=args.top_k, return_scores=True)
        print_results(results, args.query)
        # Keep the same structure as batch results so --output can also save single-query runs
        batch_results = {args.query: results}

        if args.stats:
            stats = searcher.get_search_stats(args.query)
            print("\nSimilarity Statistics:")
            for key, value in stats.items():
                print(f"  {key}: {value:.4f}")
    elif args.queries_file:
        # Batch queries from file
        try:
            with open(args.queries_file, 'r', encoding='utf-8') as f:
                queries = [line.strip() for line in f if line.strip()]
            batch_results = searcher.batch_search(queries, top_k=args.top_k)
            for query, results in batch_results.items():
                print_results(results, query)
        except Exception as e:
            logger.error(f"Failed to process queries file: {e}")
            return
    elif args.interactive:
        # Interactive mode
        print("\nInteractive Text-to-Image Search")
        print("Enter text queries (or 'quit' to exit):")
        print("-" * 40)

        while True:
            try:
                query = input("\nQuery: ").strip()
                if query.lower() in ['quit', 'exit', 'q']:
                    break
                if not query:
                    continue

                results = searcher.search(query, top_k=args.top_k, return_scores=True)
                print_results(results, query)

                if args.stats:
                    stats = searcher.get_search_stats(query)
                    print(f"\nStats: mean={stats['mean_similarity']:.3f}, "
                          f"max={stats['max_similarity']:.3f}, "
                          f"std={stats['std_similarity']:.3f}")
            except (KeyboardInterrupt, EOFError):
                # Exit cleanly on Ctrl+C or end of input
                print("\nExiting...")
                break
            except Exception as e:
                logger.error(f"Error processing query: {e}")
    else:
        # Demo mode with sample queries
        demo_queries = [
            "a cat sleeping on a bed",
        ]
        print("Demo mode: Running sample queries...")
        batch_results = searcher.batch_search(demo_queries, top_k=5)
        for query, results in batch_results.items():
            print_results(results, query)
    # Save results if requested
    if args.output and 'batch_results' in locals():
        import json

        # Convert results to JSON-serializable format
        json_results = {}
        for query, results in batch_results.items():
            json_results[query] = results  # Results are already dict format

        try:
            with open(args.output, 'w', encoding='utf-8') as f:
                json.dump(json_results, f, indent=2, ensure_ascii=False)
            logger.info(f"Results saved to {args.output}")
        except Exception as e:
            logger.error(f"Failed to save results: {e}")


if __name__ == "__main__":
    main()