Img2Vec/starter.py
2025-06-13 15:00:33 +08:00

103 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import argparse
import pickle
import numpy as np
from PIL import Image
import torch
from embedder import ImageEmbedder, DEVICE
def _positive_int(text):
    """argparse ``type=`` callback: parse *text* as an integer >= 1.

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is < 1.
            argparse turns this into a usage error (exit status 2).
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {text!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options of the batch embedding script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        An ``argparse.Namespace`` with the attributes ``image_dir``,
        ``output_dir``, ``vision_model``, ``proj_dim``, ``llava_ckpt`` and
        ``batch_size``.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments,
            including a ``--batch-size`` smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Batch-convert a folder of images into embedding vectors"
    )
    parser.add_argument(
        "--image-dir", required=True,
        help="Path to a folder containing images (jpg/png/bmp/gif)"
    )
    parser.add_argument(
        "--output-dir", default="processed_images",
        help="Where to save image_vectors.npy and index_to_file.pkl"
    )
    parser.add_argument(
        "--vision-model", default="openai/clip-vit-large-patch14-336",
        help="Hugging Face name of the CLIP vision encoder"
    )
    parser.add_argument(
        "--proj-dim", type=int, default=5120,
        help="Dimensionality of the projection output"
    )
    parser.add_argument(
        "--llava-ckpt", default=None,
        help="(Optional) full LLaVA checkpoint .bin to load projector weights from"
    )
    # A batch size of 0 or less cannot drive the batching loop, so reject it
    # here with a normal usage error instead of failing later mid-run.
    parser.add_argument(
        "--batch-size", type=_positive_int, default=64,
        help="How many images to encode per GPU/CPU batch"
    )
    return parser.parse_args(argv)
def find_images(folder):
    """Return the sorted paths of the image files directly inside *folder*.

    A directory entry qualifies when its name ends, case-insensitively, with
    one of ``.jpg``, ``.jpeg``, ``.png``, ``.bmp`` or ``.gif`` and it is a
    regular file (or a symlink to one). Sub-directories are not searched.

    Args:
        folder: Path of the directory to scan.

    Returns:
        A list of ``os.path.join(folder, name)`` strings in sorted order;
        empty when nothing matches.

    Raises:
        OSError: propagated from ``os.listdir`` (e.g. missing directory).
    """
    exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif")
    candidates = (
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(exts)
    )
    # os.listdir also yields directories; one named e.g. "album.png" would
    # otherwise be handed to the image loader and abort the whole run.
    return sorted(path for path in candidates if os.path.isfile(path))
def _load_rgb(path):
    """Open the image at *path* and return an in-memory RGB copy.

    The context manager closes the underlying file as soon as the converted
    copy exists, so file handles do not accumulate over large folders.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def main():
    """Embed every image in ``--image-dir`` and save vectors plus an index.

    Writes two files into ``--output-dir`` (created if missing):

    * ``image_vectors.npy`` - all batch embeddings stacked row-wise.
    * ``index_to_file.pkl`` - pickled dict mapping row index to image path.

    Raises:
        ValueError: if the batch size is smaller than 1.
        RuntimeError: if no images are found in the image directory.
    """
    args = parse_args()
    if args.batch_size < 1:
        # range() with a step of 0 raises an obscure error and a negative
        # step silently processes nothing; fail early and clearly.
        raise ValueError(f"--batch-size must be >= 1, got {args.batch_size}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Discover images
    image_paths = find_images(args.image_dir)
    if not image_paths:
        raise RuntimeError(f"No images found in {args.image_dir}")
    total = len(image_paths)
    print(f"Found {total} images in {args.image_dir}")

    # Build embedder
    embedder = ImageEmbedder(
        vision_model_name=args.vision_model,
        proj_out_dim=args.proj_dim,
        llava_ckpt_path=args.llava_ckpt
    )
    print(f"Using device: {DEVICE}")

    # Process in batches
    all_embs = []
    index_to_file = {}
    for batch_start in range(0, total, args.batch_size):
        batch_paths = image_paths[batch_start:batch_start + args.batch_size]
        imgs = [_load_rgb(path) for path in batch_paths]

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            embs = embedder.image_to_embedding(imgs)  # presumably (B, D)
        all_embs.append(embs.cpu().numpy())

        # Row i of the stacked output corresponds to image_paths[i].
        index_to_file.update(enumerate(batch_paths, start=batch_start))
        print(f" • Processed {batch_start + len(batch_paths)}/{total} images")

    # Stack and save
    vectors = np.vstack(all_embs)
    vec_file = os.path.join(args.output_dir, "image_vectors.npy")
    map_file = os.path.join(args.output_dir, "index_to_file.pkl")
    np.save(vec_file, vectors)
    with open(map_file, "wb") as f:
        pickle.dump(index_to_file, f)
    print(f"\nSaved {vectors.shape[0]}×{vectors.shape[1]} vectors to\n {vec_file}")
    print(f"Saved index→file mapping to\n {map_file}")
# Run the batch embedding job only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()