# Text-to-image search using CLIP embeddings.
# (Source-viewer listing metadata: 357 lines, 11 KiB, Python.)
import argparse
|
|
import pickle
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from text_embedder import CLIPTextEmbedder
|
|
|
|
# Configure logging: basicConfig runs at import time and sets the root
# logger to INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths and CLI defaults
# Pickle file of image embeddings used when --image-embeddings is omitted.
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
# Model name handed to CLIPTextEmbedder when --text-model is omitted.
DEFAULT_TEXT_MODEL = "openai/clip-vit-base-patch32"
# Number of results returned per query when --top-k is omitted.
DEFAULT_TOP_K = 10
|
|
|
|
|
|
class TextToImageSearcher:
    """
    Text-to-image search using CLIP embeddings

    Image embeddings are loaded once from a pickle file that maps an image
    path to its embedding vector. Each query is embedded with
    ``CLIPTextEmbedder`` and every image is scored by the dot product between
    its embedding and the query embedding.
    """

    def __init__(
        self,
        image_embeddings_path: Path,
        text_model_name: str = DEFAULT_TEXT_MODEL,
        embedding_dim: int = 4096,
    ):
        """
        Initialize the Text-to-Image searcher

        Args:
            image_embeddings_path: Path to the pickle file containing image embeddings
            text_model_name: Name of the CLIP text model to use
            embedding_dim: Dimension of the embeddings (should match image embeddings)

        Raises:
            FileNotFoundError: If the embeddings file does not exist
            RuntimeError: If the embeddings file cannot be loaded or validated
        """
        self.image_embeddings_path = Path(image_embeddings_path)
        self.embedding_dim = embedding_dim

        # Load image embeddings
        logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
        self.image_data = self._load_image_embeddings()

        # Initialize text embedder
        logger.info(f"Initializing text embedder: {text_model_name}")
        self.text_embedder = CLIPTextEmbedder(
            model_name=text_model_name,
            output_dim=embedding_dim,
        )

        logger.info(f"TextToImageSearcher ready with {len(self.image_data)} images")

    def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
        """
        Load image embeddings from pickle file

        Only the first entry's last dimension is compared against
        ``self.embedding_dim``; the remaining entries are not inspected.

        Returns:
            The unpickled dictionary of embeddings

        Raises:
            FileNotFoundError: If the embeddings file does not exist
            RuntimeError: For any other failure (unreadable pickle, payload
                that is not a dict, empty dict, dimension mismatch); the
                original exception is attached as ``__cause__``
        """
        try:
            # NOTE(security): pickle.load can execute arbitrary code while
            # deserializing. Only point this at embedding files you trust.
            with open(self.image_embeddings_path, "rb") as f:
                data = pickle.load(f)

            if not isinstance(data, dict):
                raise ValueError("Image embeddings file should contain a dictionary")

            # An empty dict would otherwise surface as a bare StopIteration
            # from next(iter(...)) and produce an error with no message.
            if not data:
                raise ValueError("Image embeddings file contains no embeddings")

            # Verify embedding dimensions
            sample_emb = next(iter(data.values()))
            sample_dim = sample_emb.shape[-1]
            if sample_dim != self.embedding_dim:
                raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_dim}D")

            logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_dim}")
            return data

        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Image embeddings file not found: {self.image_embeddings_path}"
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to load image embeddings: {e}") from e

    def _compute_similarities(self, query_text: str) -> np.ndarray:
        """
        Score every stored image embedding against the query text

        Args:
            query_text: Text description to embed

        Returns:
            Dot products between the stacked image embeddings and the query
            embedding, in the iteration order of ``self.image_data``
        """
        query_embedding = self.text_embedder.embed_single_text(query_text)
        image_embeddings = np.stack(list(self.image_data.values()))  # (N, D)

        # NOTE(review): a dot product is a cosine similarity only when both
        # sides are L2-normalized - assumed here, confirm against the
        # embedders that produced the vectors.
        return np.dot(image_embeddings, query_embedding)

    def search(
        self,
        query_text: str,
        top_k: int = DEFAULT_TOP_K,
        return_scores: bool = False,
    ) -> List[Dict]:
        """
        Search for images matching the query text

        Args:
            query_text: Text description to search for
            top_k: Number of top results to return; zero or a negative value
                yields an empty list
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys,
            ordered from highest to lowest similarity
        """
        logger.info(f"Searching for: '{query_text}'")

        # A negative top_k used as a slice bound would drop results from the
        # tail instead of returning none, so handle it explicitly.
        if top_k <= 0:
            return []

        similarities = self._compute_similarities(query_text)
        image_paths = list(self.image_data)

        # Negate so that argsort's ascending order ranks best matches first.
        top_k_indices = np.argsort(-similarities)[:top_k]

        results = []
        for idx in top_k_indices:
            result = {'path': image_paths[idx]}
            if return_scores:
                result['score'] = float(similarities[idx])
            results.append(result)

        return results

    def batch_search(
        self,
        query_texts: List[str],
        top_k: int = DEFAULT_TOP_K,
    ) -> Dict[str, List[Dict]]:
        """
        Search for multiple queries at once

        Queries are run one after another through ``search`` with scores
        enabled. Because the result is keyed by query text, a repeated query
        appears only once in the output.

        Args:
            query_texts: List of text descriptions
            top_k: Number of results per query

        Returns:
            Dictionary mapping query text to list of results
        """
        logger.info(f"Batch searching {len(query_texts)} queries")

        results = {}
        for query in query_texts:
            results[query] = self.search(query, top_k=top_k, return_scores=True)

        return results

    def get_search_stats(self, query_text: str) -> Dict:
        """
        Get statistics about search results

        Args:
            query_text: Query to analyze

        Returns:
            Dictionary with the mean, max, min, standard deviation and median
            of the query's similarity to all stored images
        """
        similarities = self._compute_similarities(query_text)

        return {
            'mean_similarity': float(np.mean(similarities)),
            'max_similarity': float(np.max(similarities)),
            'min_similarity': float(np.min(similarities)),
            'std_similarity': float(np.std(similarities)),
            'median_similarity': float(np.median(similarities)),
        }

    def get_memory_usage(self):
        """Get memory usage information as reported by the text embedder"""
        return self.text_embedder.get_memory_usage()

    def clear_cache(self):
        """Clear text embedding cache (delegates to the text embedder)"""
        self.text_embedder.clear_cache()
|
|
|
|
|
|
def print_results(results: List[Dict], query: str, show_scores: bool = True):
    """
    Pretty print search results

    Args:
        results: Result dictionaries, each with a 'path' key and optionally
            a 'score' key
        query: Query text shown in the header line
        show_scores: Whether to print the score column for results that
            carry a 'score' key
    """
    print(f"\nTop {len(results)} results for: '{query}'")
    print("-" * 80)

    for i, result in enumerate(results, 1):
        # Coerce to str so path-like keys (e.g. pathlib.Path) accept the
        # '<60' alignment spec instead of raising TypeError.
        path = str(result['path'])
        if show_scores and 'score' in result:
            score = result['score']
            print(f"{i:2d}. {path:<60} score={score:>7.4f}")
        else:
            print(f"{i:2d}. {path}")

    print("-" * 80)
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the text-to-image search CLI."""
    parser = argparse.ArgumentParser(description="Text-to-Image Search using CLIP embeddings")

    # Input/Output paths
    parser.add_argument(
        "--image-embeddings",
        type=str,
        default=str(DEFAULT_IMG_EMB_PATH),
        help="Path to image embeddings pickle file",
    )

    # Query options
    parser.add_argument(
        "--query",
        type=str,
        help="Single text query to search for",
    )
    parser.add_argument(
        "--queries-file",
        type=str,
        help="File with one query per line",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Interactive mode for multiple queries",
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=DEFAULT_TOP_K,
        help="Number of top results to return",
    )
    parser.add_argument(
        "--text-model",
        type=str,
        default=DEFAULT_TEXT_MODEL,
        help="CLIP text model to use",
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=4096,
        help="Embedding dimension",
    )

    # Output options
    parser.add_argument(
        "--output",
        type=str,
        help="Save results to JSON file",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show similarity statistics",
    )

    return parser


def _run_interactive(searcher, top_k: int, show_stats: bool) -> Dict[str, List[Dict]]:
    """Prompt for queries until the user quits; return the results keyed by query."""
    print("\nInteractive Text-to-Image Search")
    print("Enter text queries (or 'quit' to exit):")
    print("-" * 40)

    collected = {}
    while True:
        try:
            query = input("\nQuery: ").strip()
            if query.lower() in ['quit', 'exit', 'q']:
                break

            if not query:
                continue

            results = searcher.search(query, top_k=top_k, return_scores=True)
            print_results(results, query)
            collected[query] = results

            if show_stats:
                stats = searcher.get_search_stats(query)
                print(f"\nStats: mean={stats['mean_similarity']:.3f}, "
                      f"max={stats['max_similarity']:.3f}, "
                      f"std={stats['std_similarity']:.3f}")

        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin is exhausted (Ctrl-D or a closed pipe).
            # It must end the loop: every later input() call would raise it
            # again, so treating it as a per-query error would spin forever.
            print("\nExiting...")
            break
        except Exception as e:
            # A failed query should not end the session.
            logger.error(f"Error processing query: {e}")

    return collected


def _save_results(results: Dict[str, List[Dict]], output_path: str) -> bool:
    """Write results to output_path as JSON; return True on success."""
    import json

    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        logger.info(f"Results saved to {output_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to save results: {e}")
        return False


def main():
    """
    Command-line entry point for text-to-image search

    Exactly one mode runs, chosen in this order: ``--query`` (single query),
    ``--queries-file`` (one query per line), ``--interactive`` (prompt loop),
    otherwise a built-in demo query. When ``--output`` is given, the results
    of whichever mode ran are written to that file as JSON, keyed by query.

    Returns:
        0 on success, 1 if the searcher could not be initialized, the queries
        file could not be processed, or the results could not be saved
    """
    args = _build_arg_parser().parse_args()

    # Initialize searcher
    try:
        searcher = TextToImageSearcher(
            image_embeddings_path=args.image_embeddings,
            text_model_name=args.text_model,
            embedding_dim=args.embedding_dim,
        )
    except Exception as e:
        logger.error(f"Failed to initialize searcher: {e}")
        return 1

    # Show memory usage
    mem = searcher.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Handle different query modes; every mode fills all_results so that
    # --output works regardless of how the queries were supplied.
    if args.query:
        # Single query
        results = searcher.search(args.query, top_k=args.top_k, return_scores=True)
        print_results(results, args.query)
        all_results = {args.query: results}

        if args.stats:
            stats = searcher.get_search_stats(args.query)
            print("\nSimilarity Statistics:")
            for key, value in stats.items():
                print(f"  {key}: {value:.4f}")

    elif args.queries_file:
        # Batch queries from file
        try:
            with open(args.queries_file, 'r', encoding='utf-8') as f:
                queries = [line.strip() for line in f if line.strip()]

            all_results = searcher.batch_search(queries, top_k=args.top_k)

            for query, results in all_results.items():
                print_results(results, query)

        except Exception as e:
            logger.error(f"Failed to process queries file: {e}")
            return 1

    elif args.interactive:
        # Interactive mode
        all_results = _run_interactive(searcher, args.top_k, args.stats)

    else:
        # Demo mode with sample queries
        demo_queries = [
            "a cat sleeping on a bed",
        ]

        print("Demo mode: Running sample queries...")
        all_results = searcher.batch_search(demo_queries, top_k=5)

        for query, results in all_results.items():
            print_results(results, query)

    # Save results if requested
    if args.output and not _save_results(all_results, args.output):
        return 1

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit
    # status. SystemExit(None) exits with status 0, so a main() that returns
    # nothing still reports success.
    raise SystemExit(main())
|