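"""Batch-convert a folder of images into embedding vectors.

Every image found in --image-dir is encoded by ImageEmbedder (a CLIP vision
encoder plus a projection layer); the stacked vectors are written to
<output-dir>/image_vectors.npy and a row-index -> source-file mapping is
pickled to <output-dir>/index_to_file.pkl.

Example invocation (script name and paths are illustrative):

    python embed_images.py --image-dir ./photos --output-dir processed_images
"""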
import os
import argparse
import pickle

import numpy as np
from PIL import Image
import torch

from embedder import ImageEmbedder, DEVICE


def parse_args():
    p = argparse.ArgumentParser(
        description="Batch-convert a folder of images into embedding vectors"
    )
    p.add_argument(
        "--image-dir", required=True,
        help="Path to a folder containing images (jpg/png/bmp/gif)"
    )
    p.add_argument(
        "--output-dir", default="processed_images",
        help="Where to save image_vectors.npy and index_to_file.pkl"
    )
    p.add_argument(
        "--vision-model", default="openai/clip-vit-large-patch14-336",
        help="Hugging Face name of the CLIP vision encoder"
    )
    p.add_argument(
        "--proj-dim", type=int, default=5120,
        help="Dimensionality of the projection output"
    )
    p.add_argument(
        "--llava-ckpt", default=None,
        help="(Optional) full LLaVA checkpoint .bin to load projector weights from"
    )
    p.add_argument(
        "--batch-size", type=int, default=64,
        help="How many images to encode per GPU/CPU batch"
    )
    return p.parse_args()


def find_images(folder):
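    """Return the sorted paths of images found directly in `folder` (non-recursive)."""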
    exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif")
    return sorted([
        os.path.join(folder, fname)
        for fname in os.listdir(folder)
        if fname.lower().endswith(exts)
    ])


def main():
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Discover images
    image_paths = find_images(args.image_dir)
    if not image_paths:
        raise RuntimeError(f"No images found in {args.image_dir}")
    print(f"Found {len(image_paths)} images in {args.image_dir}")

    # Build embedder
    embedder = ImageEmbedder(
        vision_model_name=args.vision_model,
        proj_out_dim=args.proj_dim,
        llava_ckpt_path=args.llava_ckpt
    )
    print(f"Using device: {DEVICE}")

    # Process in batches
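    # Note: every batch's output is kept in memory until the final np.vstack;
    # row i of the stacked matrix corresponds to index_to_file[i].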
    all_embs = []
    index_to_file = {}
    for batch_start in range(0, len(image_paths), args.batch_size):
        batch_paths = image_paths[batch_start:batch_start + args.batch_size]
        # load images
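        # (convert("RGB") normalizes grayscale/palette/RGBA inputs to 3-channel RGB)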
        imgs = [Image.open(p).convert("RGB") for p in batch_paths]
        # embed
        with torch.no_grad():
            embs = embedder.image_to_embedding(imgs)  # (B, D)
        embs_np = embs.cpu().numpy()
        all_embs.append(embs_np)
        # record mapping
        for i, p in enumerate(batch_paths):
            index_to_file[batch_start + i] = p

        print(f"  • Processed {batch_start + len(batch_paths)}/{len(image_paths)} images")

    # Stack and save
    vectors = np.vstack(all_embs)  # shape (N, D)
    vec_file = os.path.join(args.output_dir, "image_vectors.npy")
    map_file = os.path.join(args.output_dir, "index_to_file.pkl")

    np.save(vec_file, vectors)
    with open(map_file, "wb") as f:
        pickle.dump(index_to_file, f)

    print(f"\nSaved {vectors.shape[0]}×{vectors.shape[1]} vectors to\n {vec_file}")
    print(f"Saved index→file mapping to\n {map_file}")


if __name__ == "__main__":
    main()