"""Image embedding utility: frozen CLIP vision encoder + pretrained LLaVA MLP projector."""

import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, UnidentifiedImageError
from transformers import CLIPImageProcessor, CLIPVisionModel
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
PROJECTOR_WEIGHTS_PATH = (
    Path(__file__).parent.parent
    / "models"
    / "llava-v1.5-mlp2x-336px-vicuna-7b"
    / "mm_projector.bin"
)
OUTPUT_DIM = 4096
DEFAULT_BATCH_SIZE = 32

# Default output path
EMBED_DIR = Path(__file__).parent.parent / "embeddings"
IMAGE_EMB_PATH = EMBED_DIR / "image_embeddings.pkl"


class ImageEmbedder:
    """
    Frozen CLIP vision encoder + pretrained 2-layer MLP projector.

    Uses PATCH-LEVEL features only (mean of patch embeddings) following
    LLaVA's configuration: mm_vision_select_feature="patch".

    Provides single-image and folder-wide embedding, with options to print
    only or append to a persistent pickle store via atomic writes.
    """

    def __init__(
        self,
        vision_model_name: str = DEFAULT_VISION_MODEL,
        projector_weights_path: Union[str, Path] = PROJECTOR_WEIGHTS_PATH,
        output_dim: int = OUTPUT_DIM,
    ):
        # 1) Setup CLIP processor + model
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name, output_hidden_states=True)
            .to(DEVICE)
            .eval()
        )
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # 2) Build MLP projector - PATCH ONLY (1024 input, not 2048)
        # Following LLaVA config: mm_vision_select_feature="patch", mm_hidden_size=1024
        hidden_dim = self.vision_model.config.hidden_size  # 1024, not 1024*2
        self.projector = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),   # 1024 -> 4096
            nn.GELU(),
            nn.Linear(output_dim, output_dim),   # 4096 -> 4096
        ).to(DEVICE)

        # 3) Load projector weights
        ckpt = torch.load(projector_weights_path, map_location=DEVICE)
        state = ckpt.get("projector_state_dict", ckpt)
        # Strip prefix
        clean_state = {
            (k[len("model.mm_projector."):] if k.startswith("model.mm_projector.") else k): v
            for k, v in state.items()
        }
        # Load with warning on mismatch
        missing, unexpected = self.projector.load_state_dict(clean_state, strict=False)
        if missing or unexpected:
            logger.warning(f"Loaded projector weights with missing={missing}, unexpected={unexpected}")

    def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
        """
        Embed a single PIL.Image → (output_dim,) L2-normalized vector.
        Uses PATCH-LEVEL features only (mean of patches, no CLS token).
""" try: inputs = self.processor(images=image, return_tensors="pt") inputs = {k: v.to(DEVICE) for k, v in inputs.items()} except Exception as e: logger.exception(f"Preprocessing failed: {e}") raise with torch.no_grad(): last_hidden = self.vision_model(**inputs).last_hidden_state # PATCH-ONLY: Use mean of patch embeddings, drop CLS token # last_hidden shape: (1, 1+num_patches, 1024) # [:, 0, :] = CLS token (skip) # [:, 1:, :] = patch tokens (use mean) patch_mean = last_hidden[:, 1:, :].mean(dim=1) # (1, 1024) combo = nn.functional.normalize(patch_mean, p=2, dim=1) # L2 normalize proj = self.projector(combo) # (1, 4096) emb = nn.functional.normalize(proj, p=2, dim=1) # L2 normalize output return emb.squeeze(0) # → (OUTPUT_DIM,) def embed_folder( self, folder_path: Union[str, Path], save_path: Union[str, Path] = IMAGE_EMB_PATH, extensions: List[str] = None, batch_size: int = DEFAULT_BATCH_SIZE, ) -> Dict[str, np.ndarray]: """ Embed all images in a folder (recursively), skipping already-embedded and corrupted files, saving after each batch. Uses PATCH-LEVEL features only. """ if extensions is None: extensions = [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".gif"] folder = Path(folder_path) save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) # Load existing embeddings try: with open(save_path, "rb") as fp: mapping = pickle.load(fp) except FileNotFoundError: mapping = {} except Exception as e: logger.warning(f"Could not load embeddings: {e}") mapping = {} # Collect files to embed all_files = [f for f in folder.rglob("*") if f.suffix.lower() in extensions] to_process = [f for f in all_files if str(f) not in mapping] if not to_process: logger.info("No new images to embed.") return mapping logger.info(f"Embedding {len(to_process)} images (batch size {batch_size})") for i in tqdm(range(0, len(to_process), batch_size), desc="Images"): batch = to_process[i: i + batch_size] imgs, paths = [], [] for f in batch: try: imgs.append(Image.open(f).convert("RGB")) paths.append(f) except UnidentifiedImageError: logger.warning(f"Skipping corrupted: {f}") except Exception as e: logger.warning(f"Error loading {f}: {e}") if not imgs: continue # Embed batch inputs = self.processor(images=imgs, return_tensors="pt") inputs = {k: v.to(DEVICE) for k, v in inputs.items()} with torch.no_grad(): last_hidden = self.vision_model(**inputs).last_hidden_state # PATCH-ONLY: Use mean of patches, drop CLS token patch_means = last_hidden[:, 1:, :].mean(dim=1) # (batch_size, 1024) combo = nn.functional.normalize(patch_means, p=2, dim=1) proj = self.projector(combo) # (batch_size, 4096) embs = nn.functional.normalize(proj, p=2, dim=1).cpu().detach().numpy() # Update and save for pth, emb in zip(paths, embs): mapping[str(pth)] = emb with atomic_write(save_path, mode="wb", overwrite=True) as fp: pickle.dump(mapping, fp, protocol=pickle.HIGHEST_PROTOCOL) logger.info(f"Done. Total embeddings: {len(mapping)}") return mapping def embed_image( self, image_path: Union[str, Path], save_path: Union[str, Path] = IMAGE_EMB_PATH, no_save: bool = False, ) -> torch.Tensor: """ Embed one image using PATCH-LEVEL features only: - prints the vector - unless no_save=True, also appends to the pickle at save_path Returns the torch.Tensor. 
""" img_path = Path(image_path) try: img = Image.open(img_path).convert("RGB") vec = self.image_to_embedding(img) except Exception as e: logger.error(f"Failed embedding {img_path}: {e}") raise # Print out the vector vec_np = vec.cpu().detach().numpy() print(f"{img_path} → shape={vec_np.shape}") print(vec_np) # Append to pickle unless disabled if not no_save: save_path = Path(save_path) save_path.parent.mkdir(parents=True, exist_ok=True) try: with open(save_path, "rb") as fp: mapping = pickle.load(fp) except (FileNotFoundError, pickle.PickleError): mapping = {} mapping[str(img_path)] = vec_np with atomic_write(save_path, overwrite=True) as fp: fp.write(pickle.dumps(mapping, protocol=pickle.HIGHEST_PROTOCOL)) logger.info(f"Saved embedding for {img_path} to {save_path}") return vec def main(): import argparse parser = argparse.ArgumentParser(__doc__) parser.add_argument("--image", type=str, help="Path to a single image to embed") parser.add_argument("--folder", type=str, help="Path to a folder of images") parser.add_argument("--out", type=str, default=str(IMAGE_EMB_PATH), help="Pickle to save embeddings (default: embeddings/image_embeddings.pkl)") parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE, help="Batch size for folder embedding") parser.add_argument("--no-save", action="store_true", help="When embedding a single image, do not save to the pickle") args = parser.parse_args() embedder = ImageEmbedder() if args.image: embedder.embed_image( image_path=args.image, save_path=args.out, no_save=args.no_save ) elif args.folder: embedder.embed_folder( folder_path=args.folder, save_path=args.out, batch_size=args.batch_size ) else: parser.error("Please specify --image or --folder") if __name__ == "__main__": main()