From 2103e28aea7e44785790a6a168a42462c24eb79e Mon Sep 17 00:00:00 2001 From: ldy Date: Wed, 18 Jun 2025 16:02:01 +0800 Subject: [PATCH] Rough implementation of 5120->4096 dim --- embedder.py | 182 +++++++++++++--- find_similar_img.py | 460 ++++++++++++++++++++++++++++++++++++++--- starter.py | 242 +++++++++++++++++----- text_embedder.py | 306 +++++++++++++++++++++++++++ text_search_starter.py | 356 +++++++++++++++++++++++++++++++ 5 files changed, 1440 insertions(+), 106 deletions(-) create mode 100644 text_embedder.py create mode 100644 text_search_starter.py diff --git a/embedder.py b/embedder.py index 2b4fb1a..59e0a34 100644 --- a/embedder.py +++ b/embedder.py @@ -1,17 +1,31 @@ import torch +import torch.nn as nn from transformers import CLIPImageProcessor, CLIPVisionModel from PIL import Image +import logging -# Use CUDA +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Use CUDA if available DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") class ImageEmbedder: + """ + Image embedder that outputs 4096-dimensional embeddings using CLIP vision model. + Maintains compatibility with existing LLaVA weights while reducing dimension. + """ + def __init__(self, vision_model_name: str = "openai/clip-vit-large-patch14-336", - proj_out_dim: int = 5120, + proj_out_dim: int = 4096, # Changed from 5120 to 4096 llava_ckpt_path: str = None): - # Load CLIP vision encoder + processor + + logger.info(f"Initializing ImageEmbedder with {vision_model_name}") + + # Load CLIP vision encoder + processor (keep exactly the same) self.processor = CLIPImageProcessor.from_pretrained(vision_model_name) self.vision_model = ( CLIPVisionModel @@ -19,37 +33,157 @@ class ImageEmbedder: .to(DEVICE) .eval() ) - # Freeze version + + # Freeze vision model parameters for p in self.vision_model.parameters(): p.requires_grad = False - # Build the projection layer (1024 → proj_out_dim) and move it to DEVICE + # Get vision model's hidden dimension vision_hidden_dim = self.vision_model.config.hidden_size # should be 1024 - self.projection = torch.nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE) + + # Keep the simple projection architecture that works! 
+ # Just change output dimension from 5120 to 4096 + self.projection = nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE) self.projection.eval() - # Load LLaVA’s projection weights from the full bin checkpoint + # Load LLaVA's projection weights if provided if llava_ckpt_path is not None: - ckpt = torch.load(llava_ckpt_path, map_location=DEVICE) - # extract only the projector weights + bias + self._load_llava_weights(llava_ckpt_path, vision_hidden_dim, proj_out_dim) + + logger.info(f"ImageEmbedder ready: {vision_hidden_dim}D -> {proj_out_dim}D on {DEVICE}") + + def _load_llava_weights(self, ckpt_path, vision_dim, output_dim): + """Load LLaVA projection weights, adapting for different output dimension""" + try: + ckpt = torch.load(ckpt_path, map_location=DEVICE) + + # Extract projector weights (same as your original code) keys = ["model.mm_projector.weight", "model.mm_projector.bias"] - state = { - k.replace("model.mm_projector.", ""): ckpt[k] - for k in keys - } - # Load into our linear layer - self.projection.load_state_dict(state) - def image_to_embedding(self, image: Image.Image) -> torch.Tensor: - # preprocess & move to device - inputs = self.processor(images=image, return_tensors="pt") - pixel_values = inputs.pixel_values.to(DEVICE) # (1,3,H,W) + if all(k in ckpt for k in keys): + original_weight = ckpt["model.mm_projector.weight"] # Shape: (5120, 1024) + original_bias = ckpt["model.mm_projector.bias"] # Shape: (5120,) - # Forward pass + if output_dim == 5120: + # Direct loading for original dimension + state = { + "weight": original_weight, + "bias": original_bias + } + self.projection.load_state_dict(state) + logger.info("Loaded original LLaVA projection weights") + + elif output_dim < 5120: + # Truncate weights for smaller dimension (preserves most important features) + truncated_weight = original_weight[:output_dim, :] # (4096, 1024) + truncated_bias = original_bias[:output_dim] # (4096,) + + state = { + "weight": truncated_weight, + "bias": truncated_bias + } + self.projection.load_state_dict(state) + logger.info(f"Loaded truncated LLaVA weights: {5120}D -> {output_dim}D") + + else: + # For larger dimensions, pad with Xavier initialization + new_weight = torch.zeros(output_dim, vision_dim, device=DEVICE) + new_bias = torch.zeros(output_dim, device=DEVICE) + + # Copy original weights + new_weight[:5120, :] = original_weight + new_bias[:5120] = original_bias + + # Initialize additional weights + nn.init.xavier_uniform_(new_weight[5120:, :]) + nn.init.zeros_(new_bias[5120:]) + + state = { + "weight": new_weight, + "bias": new_bias + } + self.projection.load_state_dict(state) + logger.info(f"Loaded extended LLaVA weights: {5120}D -> {output_dim}D") + else: + logger.warning("LLaVA projector weights not found, using random initialization") + + except Exception as e: + logger.warning(f"Failed to load LLaVA weights: {e}, using random initialization") + + def image_to_embedding(self, image_input) -> torch.Tensor: + """ + Convert image(s) to embeddings - keeping your exact original logic! 
+ + Args: + image_input: PIL.Image or list of PIL.Images + + Returns: + torch.Tensor: Embeddings of shape (proj_out_dim,) for single image + or (B, proj_out_dim) for batch of images + """ + # Handle single image vs batch (same as your original) + if isinstance(image_input, Image.Image): + images = [image_input] + single_image = True + else: + images = image_input + single_image = False + + # Preprocess images (same as your original) + inputs = self.processor(images=images, return_tensors="pt") + pixel_values = inputs.pixel_values.to(DEVICE) # (B, 3, H, W) + + # Forward pass (exactly like your original) with torch.no_grad(): out = self.vision_model(pixel_values=pixel_values) - feat = out.pooler_output # (1,1024) - emb = self.projection(feat) # (1,proj_out_dim) + feat = out.pooler_output # (B, 1024) + emb = self.projection(feat) # (B, proj_out_dim) - # Returns a 1D tensor of size [proj_out_dim] on DEVICE. - return emb.squeeze(0) # → (proj_out_dim,) + # Return appropriate shape (same as your original) + if single_image: + return emb.squeeze(0) # (proj_out_dim,) + else: + return emb # (B, proj_out_dim) + + def get_memory_usage(self): + """Get current GPU memory usage if available""" + if torch.cuda.is_available(): + return { + 'allocated': torch.cuda.memory_allocated() / 1024 ** 3, # GB + 'reserved': torch.cuda.memory_reserved() / 1024 ** 3, # GB + } + return {'allocated': 0, 'reserved': 0} + + +# Test function +def main(): + """Simple test of the embedder""" + import argparse + from pathlib import Path + + parser = argparse.ArgumentParser(description="Test ImageEmbedder") + parser.add_argument("--image", type=str, help="Path to test image") + parser.add_argument("--model", type=str, default="openai/clip-vit-large-patch14-336") + parser.add_argument("--dim", type=int, default=4096) + parser.add_argument("--llava-ckpt", type=str, help="Path to LLaVA checkpoint") + args = parser.parse_args() + + embedder = ImageEmbedder( + vision_model_name=args.model, + proj_out_dim=args.dim, + llava_ckpt_path=args.llava_ckpt + ) + + if args.image: + img = Image.open(args.image).convert("RGB") + emb = embedder.image_to_embedding(img) + print(f"Embedding shape: {emb.shape}") + print(f"Sample values: {emb[:10]}") + + # Show memory usage + mem = embedder.get_memory_usage() + logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB") + + +if __name__ == "__main__": + main() diff --git a/find_similar_img.py b/find_similar_img.py index 6df859e..8311709 100644 --- a/find_similar_img.py +++ b/find_similar_img.py @@ -1,36 +1,444 @@ -import numpy as np +import argparse import pickle +import logging +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np from PIL import Image from sklearn.metrics.pairwise import cosine_similarity from embedder import ImageEmbedder -# ——— Load stored embeddings & mapping ——— -vecs = np.load("imgVecs/image_vectors.npy") # (N, D) -with open("imgVecs/index_to_file.pkl", "rb") as f: - idx2file = pickle.load(f) # dict: idx → filepath +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -# ——— Specify query image ——— -query_path = "datasets/coco/val2017/1140002154.jpg" +# Default paths +DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl") +DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336" +DEFAULT_TOP_K = 10 -# ——— Embed the query image ——— -embedder = ImageEmbedder( - vision_model_name="openai/clip-vit-large-patch14-336", - proj_out_dim=5120, - 
llava_ckpt_path="datasets/pytorch_model-00003-of-00003.bin" -) -img = Image.open(query_path).convert("RGB") -q_vec = embedder.image_to_embedding(img).cpu().numpy() # (D,) -# ——— Compute similarities & retrieve top-k ——— -# (you can normalize if you built your DB with inner-product indexing) -# vecs_norm = vecs / np.linalg.norm(vecs, axis=1, keepdims=True) -# q_vec_norm = q_vec / np.linalg.norm(q_vec) -# sims = cosine_similarity(q_vec_norm.reshape(1, -1), vecs_norm).flatten() -sims = cosine_similarity(q_vec.reshape(1, -1), vecs).flatten() -top5 = sims.argsort()[-5:][::-1] +class ImageSimilaritySearcher: + """ + Image-to-image similarity search using unified embeddings dictionary + """ -# ——— Print out the results ——— -print(f"Query image: {query_path}\n") -for rank, idx in enumerate(top5, 1): - print(f"{rank:>2}. {idx2file[idx]} (score: {sims[idx]:.4f})") + def __init__( + self, + image_embeddings_path: Path, + vision_model_name: str = DEFAULT_VISION_MODEL, + embedding_dim: int = 4096, + llava_ckpt_path: Optional[str] = None + ): + self.image_embeddings_path = Path(image_embeddings_path) + self.embedding_dim = embedding_dim + + # Load image embeddings + logger.info(f"Loading image embeddings from {self.image_embeddings_path}") + self.image_data = self._load_image_embeddings() + + # Initialize image embedder for query images + logger.info(f"Initializing image embedder: {vision_model_name}") + self.image_embedder = ImageEmbedder( + vision_model_name=vision_model_name, + proj_out_dim=embedding_dim, + llava_ckpt_path=llava_ckpt_path + ) + + logger.info(f"ImageSimilaritySearcher ready with {len(self.image_data)} images") + + def _load_image_embeddings(self) -> Dict[str, np.ndarray]: + """Load image embeddings from unified pickle file""" + try: + with open(self.image_embeddings_path, "rb") as f: + data = pickle.load(f) + + if not isinstance(data, dict): + raise ValueError("Image embeddings file should contain a dictionary") + + # Verify embedding dimensions + if data: + sample_key = next(iter(data)) + sample_emb = data[sample_key] + if sample_emb.shape[-1] != self.embedding_dim: + raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D") + + logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}") + else: + logger.warning("Empty embeddings dictionary loaded") + + return data + + except FileNotFoundError: + raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}") + except Exception as e: + raise RuntimeError(f"Failed to load image embeddings: {e}") + + def search_by_image_path( + self, + query_image_path: str, + top_k: int = DEFAULT_TOP_K, + exclude_self: bool = True, + return_scores: bool = False + ) -> List[Dict]: + """ + Search for similar images given a query image path + + Args: + query_image_path: Path to query image + top_k: Number of top results to return + exclude_self: Whether to exclude the query image from results + return_scores: Whether to include similarity scores + + Returns: + List of dictionaries with 'path' and optionally 'score' keys + """ + logger.info(f"Searching for images similar to: {query_image_path}") + + # Load and embed query image + try: + query_img = Image.open(query_image_path).convert("RGB") + query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy() + except Exception as e: + raise RuntimeError(f"Failed to process query image {query_image_path}: {e}") + + return self._search_by_embedding( + query_embedding, + query_image_path if exclude_self 
else None, + top_k, + return_scores + ) + + def search_by_image( + self, + query_image: Image.Image, + top_k: int = DEFAULT_TOP_K, + return_scores: bool = False + ) -> List[Dict]: + """ + Search for similar images given a PIL Image + + Args: + query_image: PIL Image object + top_k: Number of top results to return + return_scores: Whether to include similarity scores + + Returns: + List of dictionaries with 'path' and optionally 'score' keys + """ + logger.info("Searching for similar images using provided image") + + # Embed query image + query_embedding = self.image_embedder.image_to_embedding(query_image).cpu().numpy() + + return self._search_by_embedding(query_embedding, None, top_k, return_scores) + + def _search_by_embedding( + self, + query_embedding: np.ndarray, + exclude_path: Optional[str], + top_k: int, + return_scores: bool + ) -> List[Dict]: + """Internal method to search by embedding vector""" + + # Prepare database embeddings + image_paths = list(self.image_data.keys()) + image_embeddings = np.stack(list(self.image_data.values())) # (N, D) + + # Compute cosine similarities + similarities = cosine_similarity( + query_embedding.reshape(1, -1), + image_embeddings + ).flatten() + + # Handle exclusion + valid_indices = list(range(len(image_paths))) + if exclude_path: + try: + exclude_idx = image_paths.index(exclude_path) + valid_indices.remove(exclude_idx) + similarities = np.delete(similarities, exclude_idx) + image_paths = [p for i, p in enumerate(image_paths) if i != exclude_idx] + except ValueError: + pass # Query path not in database + + # Get top-k indices + top_k_indices = np.argsort(-similarities)[:top_k] + + # Prepare results + results = [] + for idx in top_k_indices: + result = {'path': image_paths[idx]} + if return_scores: + result['score'] = float(similarities[idx]) + results.append(result) + + return results + + def batch_search( + self, + query_image_paths: List[str], + top_k: int = DEFAULT_TOP_K, + exclude_self: bool = True + ) -> Dict[str, List[Dict]]: + """ + Search for multiple query images at once + + Args: + query_image_paths: List of paths to query images + top_k: Number of results per query + exclude_self: Whether to exclude query images from results + + Returns: + Dictionary mapping query path to list of results + """ + logger.info(f"Batch searching {len(query_image_paths)} images") + + results = {} + for query_path in query_image_paths: + try: + results[query_path] = self.search_by_image_path( + query_path, + top_k=top_k, + exclude_self=exclude_self, + return_scores=True + ) + except Exception as e: + logger.error(f"Failed to search for {query_path}: {e}") + results[query_path] = [] + + return results + + def get_similarity_stats(self, query_image_path: str) -> Dict: + """ + Get statistics about similarity scores for a query image + + Args: + query_image_path: Path to query image + + Returns: + Dictionary with similarity statistics + """ + try: + query_img = Image.open(query_image_path).convert("RGB") + query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy() + except Exception as e: + raise RuntimeError(f"Failed to process query image: {e}") + + image_embeddings = np.stack(list(self.image_data.values())) + similarities = cosine_similarity( + query_embedding.reshape(1, -1), + image_embeddings + ).flatten() + + return { + 'mean_similarity': float(np.mean(similarities)), + 'max_similarity': float(np.max(similarities)), + 'min_similarity': float(np.min(similarities)), + 'std_similarity': float(np.std(similarities)), + 
'median_similarity': float(np.median(similarities)) + } + + def get_memory_usage(self): + """Get memory usage information""" + return self.image_embedder.get_memory_usage() + + +def print_results(results: List[Dict], query: str, show_scores: bool = True): + """Pretty print search results""" + print(f"\nTop {len(results)} similar images to: {Path(query).name}") + print(f"Query: {query}") + print("-" * 100) + + for i, result in enumerate(results, 1): + path = result['path'] + if show_scores and 'score' in result: + score = result['score'] + print(f"{i:2d}. {path:<80} score={score:>7.4f}") + else: + print(f"{i:2d}. {path}") + + print("-" * 100) + + +def main(): + parser = argparse.ArgumentParser(description="Image-to-Image Similarity Search") + + # Input/Output paths + parser.add_argument( + "--image-embeddings", + type=str, + default=str(DEFAULT_IMG_EMB_PATH), + help="Path to image embeddings pickle file" + ) + + # Query options + parser.add_argument( + "--query-image", + type=str, + help="Path to query image" + ) + parser.add_argument( + "--query-images-file", + type=str, + help="File with one image path per line" + ) + parser.add_argument( + "--query-dir", + type=str, + help="Directory containing query images" + ) + + # Model parameters + parser.add_argument( + "--vision-model", + type=str, + default=DEFAULT_VISION_MODEL, + help="CLIP vision model to use" + ) + parser.add_argument( + "--embedding-dim", + type=int, + default=4096, + help="Embedding dimension" + ) + parser.add_argument( + "--llava-ckpt", + type=str, + help="Path to LLaVA checkpoint for projection weights" + ) + + # Search parameters + parser.add_argument( + "--top-k", + type=int, + default=DEFAULT_TOP_K, + help="Number of top results to return" + ) + parser.add_argument( + "--include-self", + action="store_true", + help="Include query image in results" + ) + + # Output options + parser.add_argument( + "--output", + type=str, + help="Save results to JSON file" + ) + parser.add_argument( + "--stats", + action="store_true", + help="Show similarity statistics" + ) + + args = parser.parse_args() + + # Initialize searcher + try: + searcher = ImageSimilaritySearcher( + image_embeddings_path=args.image_embeddings, + vision_model_name=args.vision_model, + embedding_dim=args.embedding_dim, + llava_ckpt_path=args.llava_ckpt + ) + except Exception as e: + logger.error(f"Failed to initialize searcher: {e}") + return + + # Show memory usage + mem = searcher.get_memory_usage() + logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB") + + # Handle different query modes + if args.query_image: + # Single query image + results = searcher.search_by_image_path( + args.query_image, + top_k=args.top_k, + exclude_self=not args.include_self, + return_scores=True + ) + print_results(results, args.query_image) + + if args.stats: + stats = searcher.get_similarity_stats(args.query_image) + print(f"\nSimilarity Statistics:") + for key, value in stats.items(): + print(f" {key}: {value:.4f}") + + elif args.query_images_file: + # Batch queries from file + try: + with open(args.query_images_file, 'r', encoding='utf-8') as f: + query_paths = [line.strip() for line in f if line.strip()] + + batch_results = searcher.batch_search( + query_paths, + top_k=args.top_k, + exclude_self=not args.include_self + ) + + for query_path, results in batch_results.items(): + if results: # Only show if we got results + print_results(results, query_path) + + except Exception as e: + logger.error(f"Failed to process queries file: {e}") 
+ return + + elif args.query_dir: + # Query all images in directory + query_dir = Path(args.query_dir) + if not query_dir.exists(): + logger.error(f"Query directory does not exist: {query_dir}") + return + + # Find images in directory + image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff', '.tif'} + query_paths = [] + for ext in image_exts: + query_paths.extend(query_dir.glob(f"*{ext}")) + query_paths.extend(query_dir.glob(f"*{ext.upper()}")) + + query_paths = [str(p) for p in sorted(query_paths)] + + if not query_paths: + logger.error(f"No images found in {query_dir}") + return + + logger.info(f"Found {len(query_paths)} images in {query_dir}") + + batch_results = searcher.batch_search( + query_paths, + top_k=args.top_k, + exclude_self=not args.include_self + ) + + for query_path, results in batch_results.items(): + if results: # Only show if we got results + print_results(results, query_path) + + else: + print("Please specify a query image, query images file, or query directory") + print("Use --help for more information") + return + + # Save results if requested + if args.output and 'batch_results' in locals(): + import json + + try: + with open(args.output, 'w', encoding='utf-8') as f: + json.dump(batch_results, f, indent=2, ensure_ascii=False) + logger.info(f"Results saved to {args.output}") + except Exception as e: + logger.error(f"Failed to save results: {e}") + + +if __name__ == "__main__": + main() diff --git a/starter.py b/starter.py index 004bdae..c82a4c5 100644 --- a/starter.py +++ b/starter.py @@ -1,101 +1,231 @@ import os import argparse import pickle +import logging +from pathlib import Path +from typing import Dict import numpy as np from PIL import Image import torch +from tqdm import tqdm +from atomicwrites import atomic_write from embedder import ImageEmbedder, DEVICE +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Default paths +DEFAULT_OUTPUT_DIR = Path("embeddings") +DEFAULT_EMB_FILE = "image_embeddings.pkl" + def parse_args(): - p = argparse.ArgumentParser( - description="Batch-convert a folder of images into embedding vectors" + parser = argparse.ArgumentParser( + description="Batch-convert images into 4096D embeddings stored in unified dictionary" ) - p.add_argument( + parser.add_argument( "--image-dir", required=True, - help="Path to a folder containing images (jpg/png/bmp/gif)" + help="Path to folder containing images (jpg/png/bmp/gif/webp)" ) - p.add_argument( - "--output-dir", default="processed_images", - help="Where to save image_vectors.npy and index_to_file.pkl" + parser.add_argument( + "--output-dir", default=str(DEFAULT_OUTPUT_DIR), + help="Directory to save embeddings file" ) - p.add_argument( + parser.add_argument( + "--output-file", default=DEFAULT_EMB_FILE, + help="Filename for embeddings dictionary (pkl format)" + ) + parser.add_argument( "--vision-model", default="openai/clip-vit-large-patch14-336", - help="Hugging Face name of the CLIP vision encoder" + help="Hugging Face CLIP vision model name" ) - p.add_argument( - "--proj-dim", type=int, default=5120, - help="Dimensionality of the projection output" + parser.add_argument( + "--embedding-dim", type=int, default=4096, + help="Output embedding dimension" ) - p.add_argument( + parser.add_argument( "--llava-ckpt", default=None, - help="(Optional) full LLaVA checkpoint .bin to load projector weights from" + help="(Optional) LLaVA checkpoint to load projection weights from" ) - p.add_argument( - "--batch-size", type=int, 
default=64, - help="How many images to encode per GPU/CPU batch" + parser.add_argument( + "--batch-size", type=int, default=32, + help="Batch size for processing images" ) - return p.parse_args() + parser.add_argument( + "--resume", action="store_true", + help="Resume from existing embeddings file (skip already processed images)" + ) + return parser.parse_args() -def find_images(folder): - exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif") - return sorted([ - os.path.join(folder, fname) - for fname in os.listdir(folder) - if fname.lower().endswith(exts) - ]) +def find_images(folder: str) -> list: + """Find all supported image files in folder (non-recursive by default)""" + supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".tif") + image_paths = [] + + folder_path = Path(folder) + + # Use a set to avoid duplicates + found_files = set() + + # Search for files with supported extensions (case-insensitive) + for file_path in folder_path.iterdir(): + if file_path.is_file(): + file_ext = file_path.suffix.lower() + if file_ext in supported_exts: + found_files.add(str(file_path.resolve())) + + return sorted(list(found_files)) + + +def load_existing_embeddings(filepath: Path) -> Dict[str, np.ndarray]: + """Load existing embeddings dictionary""" + if not filepath.exists(): + return {} + + try: + with open(filepath, "rb") as f: + embeddings = pickle.load(f) + logger.info(f"Loaded {len(embeddings)} existing embeddings from {filepath}") + return embeddings + except Exception as e: + logger.warning(f"Failed to load existing embeddings: {e}") + return {} + + +def save_embeddings(embeddings: Dict[str, np.ndarray], filepath: Path): + """Safely save embeddings dictionary""" + filepath.parent.mkdir(parents=True, exist_ok=True) + + with atomic_write(filepath, mode='wb', overwrite=True) as f: + pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL) + + logger.info(f"Saved {len(embeddings)} embeddings to {filepath}") def main(): args = parse_args() - os.makedirs(args.output_dir, exist_ok=True) + + # Setup paths + output_dir = Path(args.output_dir) + output_file = output_dir / args.output_file # Discover images + logger.info(f"Scanning for images in {args.image_dir}") image_paths = find_images(args.image_dir) - if not image_paths: - raise RuntimeError(f"No images found in {args.image_dir}") - print(f"Found {len(image_paths)} images in {args.image_dir}") - # Build embedder + if not image_paths: + raise RuntimeError(f"No supported images found in {args.image_dir}") + + logger.info(f"Found {len(image_paths)} images") + + # Load existing embeddings if resuming + embeddings_dict = {} + if args.resume: + embeddings_dict = load_existing_embeddings(output_file) + + # Filter out already processed images + new_image_paths = [p for p in image_paths if p not in embeddings_dict] + logger.info(f"Resuming: {len(embeddings_dict)} already processed, {len(new_image_paths)} remaining") + image_paths = new_image_paths + + if not image_paths: + logger.info("All images already processed!") + return + + # Initialize embedder + logger.info(f"Initializing ImageEmbedder on {DEVICE}") embedder = ImageEmbedder( vision_model_name=args.vision_model, - proj_out_dim=args.proj_dim, + proj_out_dim=args.embedding_dim, llava_ckpt_path=args.llava_ckpt ) - print(f"Using device: {DEVICE}") - # Process in batches - all_embs = [] - index_to_file = {} + # Show memory usage + mem = embedder.get_memory_usage() + logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB") + + # Process 
images in batches + processed_count = 0 + failed_count = 0 + + progress_bar = tqdm(total=len(image_paths), desc="Processing images") + for batch_start in range(0, len(image_paths), args.batch_size): - batch_paths = image_paths[batch_start:batch_start + args.batch_size] - # load images - imgs = [Image.open(p).convert("RGB") for p in batch_paths] - # embed - with torch.no_grad(): - embs = embedder.image_to_embedding(imgs) # (B, D) - embs_np = embs.cpu().numpy() - all_embs.append(embs_np) - # record mapping - for i, p in enumerate(batch_paths): - index_to_file[batch_start + i] = p + batch_end = min(batch_start + args.batch_size, len(image_paths)) + batch_paths = image_paths[batch_start:batch_end] - print(f" • Processed {batch_start + len(batch_paths)}/{len(image_paths)} images") + # Load batch images + batch_images = [] + valid_paths = [] - # Stack and save - vectors = np.vstack(all_embs) # shape (N, D) - vec_file = os.path.join(args.output_dir, "image_vectors.npy") - map_file = os.path.join(args.output_dir, "index_to_file.pkl") + for img_path in batch_paths: + try: + img = Image.open(img_path).convert("RGB") + batch_images.append(img) + valid_paths.append(img_path) + except Exception as e: + logger.warning(f"Failed to load {img_path}: {e}") + failed_count += 1 + progress_bar.update(1) + continue - np.save(vec_file, vectors) - with open(map_file, "wb") as f: - pickle.dump(index_to_file, f) + if not batch_images: + continue - print(f"\nSaved {vectors.shape[0]}×{vectors.shape[1]} vectors to\n {vec_file}") - print(f"Saved index→file mapping to\n {map_file}") + # Generate embeddings for batch + try: + with torch.no_grad(): + if len(batch_images) == 1: + embeddings = embedder.image_to_embedding(batch_images[0]).cpu().numpy() + embeddings = embeddings.reshape(1, -1) # Make it (1, D) + else: + embeddings = embedder.image_to_embedding(batch_images).cpu().numpy() # (B, D) + + # Store embeddings in dictionary + for img_path, embedding in zip(valid_paths, embeddings): + embeddings_dict[img_path] = embedding + processed_count += 1 + + progress_bar.update(len(valid_paths)) + + # Periodic save every 10 batches + if (batch_start // args.batch_size) % 10 == 0: + save_embeddings(embeddings_dict, output_file) + + except Exception as e: + logger.error(f"Failed to process batch starting at {batch_start}: {e}") + failed_count += len(batch_paths) + progress_bar.update(len(batch_paths)) + continue + + progress_bar.close() + + # Final save + save_embeddings(embeddings_dict, output_file) + + # Summary + total_processed = len(embeddings_dict) + logger.info(f"\nProcessing complete!") + logger.info(f"Total embeddings: {total_processed}") + logger.info(f"Newly processed: {processed_count}") + logger.info(f"Failed: {failed_count}") + logger.info(f"Embedding dimension: {args.embedding_dim}") + logger.info(f"Output file: {output_file}") + + # Verify a sample embedding + if embeddings_dict: + sample_path = next(iter(embeddings_dict)) + sample_emb = embeddings_dict[sample_path] + logger.info(f"Sample embedding shape: {sample_emb.shape}") + logger.info(f"Sample embedding norm: {np.linalg.norm(sample_emb):.4f}") + + # Final memory usage + mem = embedder.get_memory_usage() + logger.info(f"Final GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB") if __name__ == "__main__": diff --git a/text_embedder.py b/text_embedder.py new file mode 100644 index 0000000..4f1d12b --- /dev/null +++ b/text_embedder.py @@ -0,0 +1,306 @@ +import pickle +import hashlib +import logging +from pathlib import Path +from 
typing import Dict, List, Union + +import torch +import torch.nn as nn +import numpy as np +from transformers import CLIPTokenizer, CLIPTextModel +from cachetools import LRUCache +from tqdm import tqdm +from atomicwrites import atomic_write + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Device & defaults +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# Use CLIP which is specifically designed for text-image alignment +DEFAULT_MODEL = "openai/clip-vit-base-patch32" +OUTPUT_DIM = 4096 # Match image embeddings dimension +CHUNK_BATCH_SIZE = 16 +ARTICLE_CACHE_SIZE = 1024 + +# Persistence paths +EMBEDDINGS_DIR = Path(__file__).parent / "embeddings" +TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl" + + +class CLIPTextEmbedder: + """ + Text embedder using CLIP's text encoder, optimized for text-to-image + retrieval with 4096-dimensional output to match image embeddings. + """ + + def __init__( + self, + model_name: str = DEFAULT_MODEL, + output_dim: int = OUTPUT_DIM, + chunk_batch_size: int = CHUNK_BATCH_SIZE, + ): + logger.info(f"Initializing CLIPTextEmbedder with {model_name}") + + # Load CLIP text encoder + self.tokenizer = CLIPTokenizer.from_pretrained(model_name) + self.text_encoder = CLIPTextModel.from_pretrained(model_name).to(DEVICE).eval() + + # Freeze the text encoder + for p in self.text_encoder.parameters(): + p.requires_grad = False + + # Get CLIP's text embedding dimension (usually 512) + clip_dim = self.text_encoder.config.hidden_size + + # Enhanced projector to match 4096-D image embeddings + self.projector = nn.Sequential( + nn.Linear(clip_dim, output_dim // 2), + nn.LayerNorm(output_dim // 2), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(output_dim // 2, output_dim), + nn.LayerNorm(output_dim), + nn.ReLU(), + nn.Linear(output_dim, output_dim), + ).to(DEVICE) + + # Initialize projector with Xavier initialization + for layer in self.projector: + if isinstance(layer, nn.Linear): + nn.init.xavier_uniform_(layer.weight) + nn.init.zeros_(layer.bias) + + self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE) + self.chunk_batch_size = chunk_batch_size + + # Ensure embeddings dir + EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) + + logger.info(f"CLIPTextEmbedder ready: {clip_dim}D -> {output_dim}D on {DEVICE}") + + def text_to_embedding(self, text: Union[str, List[str]]) -> torch.Tensor: + """ + Embed text using CLIP's text encoder + projector + + Args: + text: Single string or list of strings + + Returns: + torch.Tensor: Normalized embeddings of shape (D,) for single text + or (B, D) for batch of texts + """ + # Handle single text vs batch + if isinstance(text, str): + texts = [text] + single_text = True + else: + texts = text + single_text = False + + # Check cache for single text + if single_text and texts[0] in self.article_cache: + return self.article_cache[texts[0]] + + # Tokenize and encode with CLIP + inputs = self.tokenizer( + texts, + return_tensors="pt", + truncation=True, + padding=True, + max_length=77 # CLIP's max length + ).to(DEVICE) + + with torch.no_grad(): + # Get CLIP text features + text_outputs = self.text_encoder(**inputs) + # Use pooled output (CLS token equivalent) + clip_features = text_outputs.pooler_output # [batch_size, clip_dim] + + # Normalize CLIP features + clip_features = nn.functional.normalize(clip_features, p=2, dim=1) + + # Project to target dimension + projected = self.projector(clip_features) + # Normalize final output + projected = 
nn.functional.normalize(projected, p=2, dim=1) + + # Cache single text result + if single_text: + result = projected.squeeze(0).cpu() + self.article_cache[texts[0]] = result + return result + else: + return projected.cpu() # Return batch + + def embed_texts_from_file( + self, + text_file: Union[str, Path], + save_path: Union[str, Path] = TEXT_EMB_PATH + ) -> Dict[str, np.ndarray]: + """ + Read newline-separated texts and embed them + """ + save_path = Path(save_path) + save_path.parent.mkdir(parents=True, exist_ok=True) + + # Load existing embeddings + try: + with open(save_path, "rb") as fp: + mapping: Dict[str, np.ndarray] = pickle.load(fp) + logger.info(f"Loaded {len(mapping)} existing text embeddings") + except FileNotFoundError: + mapping = {} + except Exception as e: + logger.warning(f"Failed loading {save_path}: {e}") + mapping = {} + + # Read text file + try: + with open(text_file, "r", encoding="utf-8") as fp: + lines = [ln.strip() for ln in fp if ln.strip()] + except Exception as e: + raise RuntimeError(f"Error reading {text_file}: {e}") + + logger.info(f"Processing {len(lines)} texts from {text_file}") + + # Process in batches + new_texts = [ln for ln in lines if ln not in mapping] + logger.info(f"Embedding {len(new_texts)} new texts") + + if not new_texts: + logger.info("All texts already embedded!") + return mapping + + # Process in batches + for i in tqdm(range(0, len(new_texts), self.chunk_batch_size), desc="Embedding texts"): + batch_texts = new_texts[i:i + self.chunk_batch_size] + + try: + # Get batch embeddings + batch_embeddings = self.text_to_embedding(batch_texts) + + # Store in mapping + for text, embedding in zip(batch_texts, batch_embeddings): + mapping[text] = embedding.numpy() + + except Exception as e: + logger.error(f"Embedding failed for batch starting at {i}: {e}") + # Process individually as fallback + for text in batch_texts: + try: + vec = self.text_to_embedding(text) + mapping[text] = vec.numpy() + except Exception as e2: + logger.error(f"Individual embedding failed for text: {e2}") + + # Save updated mappings + with atomic_write(save_path, mode='wb', overwrite=True) as f: + pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL) + + logger.info(f"Saved {len(mapping)} text embeddings to {save_path}") + return mapping + + def embed_single_text(self, text: str) -> np.ndarray: + """Embed a single text and return as numpy array""" + embedding = self.text_to_embedding(text) + return embedding.numpy() + + def clear_cache(self): + """Clear the text cache""" + self.article_cache.clear() + logger.info("Text cache cleared") + + def get_memory_usage(self): + """Get current GPU memory usage if available.""" + if torch.cuda.is_available(): + return { + 'allocated': torch.cuda.memory_allocated() / 1024 ** 3, # GB + 'reserved': torch.cuda.memory_reserved() / 1024 ** 3, # GB + } + return {'allocated': 0, 'reserved': 0} + + +# Alias for compatibility +TextEmbedder = CLIPTextEmbedder + + +def main(): + import argparse + + parser = argparse.ArgumentParser("CLIP Text Embedder for Image Search") + parser.add_argument("--text", type=str, help="Single text to embed") + parser.add_argument("--file", type=str, help="Path to newline-delimited texts") + parser.add_argument( + "--out", type=str, default=str(TEXT_EMB_PATH), + help="Pickle path for saving embeddings" + ) + parser.add_argument( + "--batch-size", type=int, default=CHUNK_BATCH_SIZE, + help="Batch size for processing" + ) + parser.add_argument( + "--model", type=str, default=DEFAULT_MODEL, + help="CLIP model name to use" 
+ ) + parser.add_argument( + "--dim", type=int, default=OUTPUT_DIM, + help="Output embedding dimension" + ) + args = parser.parse_args() + + embedder = CLIPTextEmbedder( + model_name=args.model, + output_dim=args.dim, + chunk_batch_size=args.batch_size, + ) + + # Show memory usage + mem = embedder.get_memory_usage() + logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB") + + # Single-text mode + if args.text: + vec = embedder.embed_single_text(args.text) + print(f"Text: {args.text}") + print(f"Shape: {vec.shape}") + print(f"Norm: {np.linalg.norm(vec):.4f}") + print(f"Sample values: {vec[:10]}") + + # Save to file + save_path = Path(args.out) + save_path.parent.mkdir(parents=True, exist_ok=True) + try: + with open(save_path, "rb") as fp: + mapping = pickle.load(fp) + except (FileNotFoundError, pickle.PickleError): + mapping = {} + + mapping[args.text] = vec + with atomic_write(save_path, mode='wb', overwrite=True) as f: + pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL) + logger.info(f"Saved text embedding to {save_path}") + + # Batch-file mode + if args.file: + embedder.embed_texts_from_file(args.file, save_path=args.out) + logger.info(f"Completed batch embedding from {args.file}") + + if not args.text and not args.file: + # Demo mode + demo_texts = [ + "a cat sleeping on a bed", + "a dog running in the park", + "beautiful sunset over mountains", + "red sports car on highway" + ] + + print("Demo: Embedding sample texts...") + for text in demo_texts: + vec = embedder.embed_single_text(text) + print(f"'{text}' -> shape {vec.shape}, norm {np.linalg.norm(vec):.4f}") + + +if __name__ == "__main__": + main() diff --git a/text_search_starter.py b/text_search_starter.py new file mode 100644 index 0000000..df8bb43 --- /dev/null +++ b/text_search_starter.py @@ -0,0 +1,356 @@ +import argparse +import pickle +import logging +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +import torch + +from text_embedder import CLIPTextEmbedder + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Default paths +DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl") +DEFAULT_TEXT_MODEL = "openai/clip-vit-base-patch32" +DEFAULT_TOP_K = 10 + + +class TextToImageSearcher: + """ + Text-to-image search using CLIP embeddings + """ + + def __init__( + self, + image_embeddings_path: Path, + text_model_name: str = DEFAULT_TEXT_MODEL, + embedding_dim: int = 4096 + ): + """ + Initialize the Text-to-Image searcher + + Args: + image_embeddings_path: Path to the pickle file containing image embeddings + text_model_name: Name of the CLIP text model to use + embedding_dim: Dimension of the embeddings (should match image embeddings) + """ + self.image_embeddings_path = Path(image_embeddings_path) + self.embedding_dim = embedding_dim + + # Load image embeddings + logger.info(f"Loading image embeddings from {self.image_embeddings_path}") + self.image_data = self._load_image_embeddings() + + # Initialize text embedder + logger.info(f"Initializing text embedder: {text_model_name}") + self.text_embedder = CLIPTextEmbedder( + model_name=text_model_name, + output_dim=embedding_dim + ) + + logger.info(f"TextToImageSearcher ready with {len(self.image_data)} images") + + def _load_image_embeddings(self) -> Dict[str, np.ndarray]: + """Load image embeddings from pickle file""" + try: + with open(self.image_embeddings_path, "rb") as f: + data = pickle.load(f) + + if not isinstance(data, 
dict): + raise ValueError("Image embeddings file should contain a dictionary") + + # Verify embedding dimensions + sample_key = next(iter(data)) + sample_emb = data[sample_key] + if sample_emb.shape[-1] != self.embedding_dim: + raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D") + + logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}") + return data + + except FileNotFoundError: + raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}") + except Exception as e: + raise RuntimeError(f"Failed to load image embeddings: {e}") + + def search( + self, + query_text: str, + top_k: int = DEFAULT_TOP_K, + return_scores: bool = False + ) -> List[Dict]: + """ + Search for images matching the query text + + Args: + query_text: Text description to search for + top_k: Number of top results to return + return_scores: Whether to include similarity scores + + Returns: + List of dictionaries with 'path' and optionally 'score' keys + """ + logger.info(f"Searching for: '{query_text}'") + + # Embed query text + query_embedding = self.text_embedder.embed_single_text(query_text) + + # Prepare image data for comparison + image_paths = list(self.image_data.keys()) + image_embeddings = np.stack(list(self.image_data.values())) # (N, D) + + # Compute cosine similarities (both embeddings are normalized) + similarities = np.dot(image_embeddings, query_embedding) + + # Get top-k indices + top_k_indices = np.argsort(-similarities)[:top_k] + + # Prepare results + results = [] + for idx in top_k_indices: + result = {'path': image_paths[idx]} + if return_scores: + result['score'] = float(similarities[idx]) + results.append(result) + + return results + + def batch_search( + self, + query_texts: List[str], + top_k: int = DEFAULT_TOP_K + ) -> Dict[str, List[Dict]]: + """ + Search for multiple queries at once + + Args: + query_texts: List of text descriptions + top_k: Number of results per query + + Returns: + Dictionary mapping query text to list of results + """ + logger.info(f"Batch searching {len(query_texts)} queries") + + results = {} + for query in query_texts: + results[query] = self.search(query, top_k=top_k, return_scores=True) + + return results + + def get_search_stats(self, query_text: str) -> Dict: + """ + Get statistics about search results + + Args: + query_text: Query to analyze + + Returns: + Dictionary with similarity statistics + """ + query_embedding = self.text_embedder.embed_single_text(query_text) + image_embeddings = np.stack(list(self.image_data.values())) + similarities = np.dot(image_embeddings, query_embedding) + + return { + 'mean_similarity': float(np.mean(similarities)), + 'max_similarity': float(np.max(similarities)), + 'min_similarity': float(np.min(similarities)), + 'std_similarity': float(np.std(similarities)), + 'median_similarity': float(np.median(similarities)) + } + + def get_memory_usage(self): + """Get memory usage information""" + return self.text_embedder.get_memory_usage() + + def clear_cache(self): + """Clear text embedding cache""" + self.text_embedder.clear_cache() + + +def print_results(results: List[Dict], query: str, show_scores: bool = True): + """Pretty print search results""" + print(f"\nTop {len(results)} results for: '{query}'") + print("-" * 80) + + for i, result in enumerate(results, 1): + path = result['path'] + if show_scores and 'score' in result: + score = result['score'] + print(f"{i:2d}. {path:<60} score={score:>7.4f}") + else: + print(f"{i:2d}. 
{path}") + + print("-" * 80) + + +def main(): + parser = argparse.ArgumentParser(description="Text-to-Image Search using CLIP embeddings") + + # Input/Output paths + parser.add_argument( + "--image-embeddings", + type=str, + default=str(DEFAULT_IMG_EMB_PATH), + help="Path to image embeddings pickle file" + ) + + # Query options + parser.add_argument( + "--query", + type=str, + help="Single text query to search for" + ) + parser.add_argument( + "--queries-file", + type=str, + help="File with one query per line" + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Interactive mode for multiple queries" + ) + + # Search parameters + parser.add_argument( + "--top-k", + type=int, + default=DEFAULT_TOP_K, + help="Number of top results to return" + ) + parser.add_argument( + "--text-model", + type=str, + default=DEFAULT_TEXT_MODEL, + help="CLIP text model to use" + ) + parser.add_argument( + "--embedding-dim", + type=int, + default=4096, + help="Embedding dimension" + ) + + # Output options + parser.add_argument( + "--output", + type=str, + help="Save results to JSON file" + ) + parser.add_argument( + "--stats", + action="store_true", + help="Show similarity statistics" + ) + + args = parser.parse_args() + + # Initialize searcher + try: + searcher = TextToImageSearcher( + image_embeddings_path=args.image_embeddings, + text_model_name=args.text_model, + embedding_dim=args.embedding_dim + ) + except Exception as e: + logger.error(f"Failed to initialize searcher: {e}") + return + + # Show memory usage + mem = searcher.get_memory_usage() + logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB") + + # Handle different query modes + if args.query: + # Single query + results = searcher.search(args.query, top_k=args.top_k, return_scores=True) + print_results(results, args.query) + + if args.stats: + stats = searcher.get_search_stats(args.query) + print(f"\nSimilarity Statistics:") + for key, value in stats.items(): + print(f" {key}: {value:.4f}") + + elif args.queries_file: + # Batch queries from file + try: + with open(args.queries_file, 'r', encoding='utf-8') as f: + queries = [line.strip() for line in f if line.strip()] + + batch_results = searcher.batch_search(queries, top_k=args.top_k) + + for query, results in batch_results.items(): + print_results(results, query) + + except Exception as e: + logger.error(f"Failed to process queries file: {e}") + return + + elif args.interactive: + # Interactive mode + print("\nInteractive Text-to-Image Search") + print("Enter text queries (or 'quit' to exit):") + print("-" * 40) + + while True: + try: + query = input("\nQuery: ").strip() + if query.lower() in ['quit', 'exit', 'q']: + break + + if not query: + continue + + results = searcher.search(query, top_k=args.top_k, return_scores=True) + print_results(results, query) + + if args.stats: + stats = searcher.get_search_stats(query) + print(f"\nStats: mean={stats['mean_similarity']:.3f}, " + f"max={stats['max_similarity']:.3f}, " + f"std={stats['std_similarity']:.3f}") + + except KeyboardInterrupt: + print("\nExiting...") + break + except Exception as e: + logger.error(f"Error processing query: {e}") + + else: + # Demo mode with sample queries + demo_queries = [ + "a cat sleeping on a bed", + ] + + print("Demo mode: Running sample queries...") + batch_results = searcher.batch_search(demo_queries, top_k=5) + + for query, results in batch_results.items(): + print_results(results, query) + + # Save results if requested + if args.output and 
'batch_results' in locals(): + import json + + # Convert results to JSON-serializable format + json_results = {} + for query, results in batch_results.items(): + json_results[query] = results # Results are already dict format + + try: + with open(args.output, 'w', encoding='utf-8') as f: + json.dump(json_results, f, indent=2, ensure_ascii=False) + logger.info(f"Results saved to {args.output}") + except Exception as e: + logger.error(f"Failed to save results: {e}") + + +if __name__ == "__main__": + main()
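
The core of the 5120 -> 4096 change in embedder.py is how the LLaVA mm_projector is adapted: the checkpoint stores a single linear projection (weight of shape (5120, 1024), bias of shape (5120,)), and the patch keeps only its first 4096 output rows. Rows of a linear layer carry no importance ordering, so this simply drops the last 1024 output channels rather than selecting the "most important" ones. A minimal sketch of the truncation, assuming the same checkpoint shard the old find_similar_img.py pointed at:

import torch
import torch.nn as nn

OUT_DIM, VISION_DIM = 4096, 1024

ckpt = torch.load("datasets/pytorch_model-00003-of-00003.bin", map_location="cpu")
weight = ckpt["model.mm_projector.weight"]  # (5120, 1024) in the 13B checkpoint
bias = ckpt["model.mm_projector.bias"]      # (5120,)

projection = nn.Linear(VISION_DIM, OUT_DIM)
with torch.no_grad():
    projection.weight.copy_(weight[:OUT_DIM, :])  # keep the first 4096 output rows
    projection.bias.copy_(bias[:OUT_DIM])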
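
Index building in starter.py boils down to one unified {image_path: 4096-d vector} dictionary, saved with atomic_write so an interrupted run never leaves a truncated pickle, and reloaded under --resume so already-processed paths are skipped. A condensed sketch of that loop, assuming the COCO folder referenced elsewhere in the patch:

import pickle
from pathlib import Path

from PIL import Image
from atomicwrites import atomic_write

from embedder import ImageEmbedder

out_path = Path("embeddings/image_embeddings.pkl")
out_path.parent.mkdir(parents=True, exist_ok=True)

# Resume support: start from whatever is already on disk.
index = {}
if out_path.exists():
    with open(out_path, "rb") as f:
        index = pickle.load(f)

embedder = ImageEmbedder(proj_out_dim=4096)

for p in sorted(Path("datasets/coco/val2017").glob("*.jpg")):
    key = str(p.resolve())
    if key in index:
        continue  # already embedded in a previous run
    img = Image.open(p).convert("RGB")
    index[key] = embedder.image_to_embedding(img).cpu().numpy()

with atomic_write(out_path, mode="wb", overwrite=True) as f:
    pickle.dump(index, f, protocol=pickle.HIGHEST_PROTOCOL)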
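
Image-to-image retrieval is then a single cosine-similarity call over the stacked vectors; ImageSimilaritySearcher wraps this with exclusion, batching, and stats, but the bare query looks like the following sketch ("query.jpg" is a placeholder path):

import pickle

import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from embedder import ImageEmbedder

with open("embeddings/image_embeddings.pkl", "rb") as f:
    index = pickle.load(f)

paths = list(index.keys())
vecs = np.stack(list(index.values()))  # (N, 4096)

embedder = ImageEmbedder(proj_out_dim=4096)
q = embedder.image_to_embedding(Image.open("query.jpg").convert("RGB")).cpu().numpy()

sims = cosine_similarity(q.reshape(1, -1), vecs)[0]
for rank, i in enumerate(np.argsort(-sims)[:5], 1):
    print(f"{rank:>2}. {paths[i]}  (score: {sims[i]:.4f})")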
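
Text-to-image retrieval follows the same pattern through CLIPTextEmbedder, with two caveats that follow from the code: the text-side projector is freshly Xavier-initialised and this patch adds no training step, so scores only become meaningful once those projector weights are trained or loaded; and the image vectors written by embedder.py are not unit-normalised, so the sketch below normalises them explicitly instead of relying on the raw dot product in text_search_starter.py:

import pickle

import numpy as np

from text_embedder import CLIPTextEmbedder

with open("embeddings/image_embeddings.pkl", "rb") as f:
    index = pickle.load(f)

paths = list(index.keys())
vecs = np.stack(list(index.values()))                      # (N, 4096)
vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)  # unit-normalise image side

embedder = CLIPTextEmbedder(output_dim=4096)
q = embedder.embed_single_text("a cat sleeping on a bed")  # already unit-norm

sims = vecs @ q
for rank, i in enumerate(np.argsort(-sims)[:5], 1):
    print(f"{rank:>2}. {paths[i]}  (score: {sims[i]:.4f})")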