"""Batch-convert a folder of images into embedding vectors.

Writes two files into ``--output-dir``:

* ``image_vectors.npy`` -- the embeddings stacked row-wise, one row per image,
  in the sorted order returned by :func:`find_images`.
* ``index_to_file.pkl`` -- a pickled ``dict`` mapping each row index to the
  image path that produced it.
"""

import os
import argparse
import pickle

import numpy as np
from PIL import Image
import torch

from embedder import ImageEmbedder, DEVICE


# Lower-case file suffixes that find_images() treats as images by default.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif")


def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, so a bare ``parse_args()`` call
            behaves exactly as before.

    Returns:
        The populated ``argparse.Namespace``. Exits through argparse with a
        usage error if ``--batch-size`` is not a positive integer.
    """
    p = argparse.ArgumentParser(
        description="Batch-convert a folder of images into embedding vectors"
    )
    p.add_argument(
        "--image-dir",
        required=True,
        help="Path to a folder containing images (jpg/png/bmp/gif)",
    )
    p.add_argument(
        "--output-dir",
        default="processed_images",
        help="Where to save image_vectors.npy and index_to_file.pkl",
    )
    p.add_argument(
        "--vision-model",
        default="openai/clip-vit-large-patch14-336",
        help="Hugging Face name of the CLIP vision encoder",
    )
    p.add_argument(
        "--proj-dim",
        type=int,
        default=5120,
        help="Dimensionality of the projection output",
    )
    p.add_argument(
        "--llava-ckpt",
        default=None,
        help="(Optional) full LLaVA checkpoint .bin to load projector weights from",
    )
    p.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="How many images to encode per GPU/CPU batch",
    )
    args = p.parse_args(argv)
    # A zero step makes range() raise, and a negative step processes nothing
    # and later fails in np.vstack on an empty list -- reject both up front
    # with a clear message instead.
    if args.batch_size < 1:
        p.error("--batch-size must be a positive integer")
    return args


def find_images(folder, exts=IMAGE_EXTENSIONS):
    """Return the sorted full paths of the image files directly inside *folder*.

    Args:
        folder: Directory to scan. Only its immediate entries are listed
            (``os.listdir``); sub-directories are not searched.
        exts: Suffix, or iterable of suffixes, matched case-insensitively
            against each entry name. Defaults to ``IMAGE_EXTENSIONS``.

    Returns:
        A sorted list of ``os.path.join(folder, name)`` paths for entries whose
        name ends with one of *exts* and that are regular files. Directories
        that merely carry an image-like name (e.g. ``photos.jpg/``) are skipped,
        since ``Image.open`` could not read them.
    """
    if isinstance(exts, str):
        exts = (exts,)
    # str.endswith needs a tuple; lower-case so matching is case-insensitive.
    suffixes = tuple(e.lower() for e in exts)
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(folder, name))
    )


def _load_rgb(path):
    """Load *path* as an in-memory RGB image, closing the source file."""
    # Image.open is lazy and holds the file handle; for multi-frame formats
    # such as GIF, Pillow leaves it open even after loading. convert() returns
    # a new, fully-loaded image, so the source can be closed right away.
    with Image.open(path) as img:
        return img.convert("RGB")


def main():
    """Embed every image in ``--image-dir`` and save vectors plus an index map.

    Raises:
        RuntimeError: if ``--image-dir`` contains no matching image files.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Discover images
    image_paths = find_images(args.image_dir)
    if not image_paths:
        raise RuntimeError(f"No images found in {args.image_dir}")
    total = len(image_paths)
    print(f"Found {total} images in {args.image_dir}")

    # Build embedder
    embedder = ImageEmbedder(
        vision_model_name=args.vision_model,
        proj_out_dim=args.proj_dim,
        llava_ckpt_path=args.llava_ckpt,
    )
    print(f"Using device: {DEVICE}")

    # Process in batches
    all_embs = []
    index_to_file = {}
    for batch_start in range(0, total, args.batch_size):
        batch_paths = image_paths[batch_start:batch_start + args.batch_size]

        imgs = [_load_rgb(p) for p in batch_paths]

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            embs = embedder.image_to_embedding(imgs)  # expected shape (B, D)
        all_embs.append(embs.cpu().numpy())

        # Row k of the final matrix corresponds to image_paths[k].
        for row, path in enumerate(batch_paths, start=batch_start):
            index_to_file[row] = path

        print(f" • Processed {batch_start + len(batch_paths)}/{total} images")

    # Stack and save
    vectors = np.vstack(all_embs)  # shape (N, D)
    vec_file = os.path.join(args.output_dir, "image_vectors.npy")
    map_file = os.path.join(args.output_dir, "index_to_file.pkl")
    np.save(vec_file, vectors)
    with open(map_file, "wb") as f:
        pickle.dump(index_to_file, f)

    print(f"\nSaved {vectors.shape[0]}×{vectors.shape[1]} vectors to\n {vec_file}")
    print(f"Saved index→file mapping to\n {map_file}")


if __name__ == "__main__":
    main()