Img2Vec/starter.py

233 lines
7.4 KiB
Python

import os
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from atomicwrites import atomic_write
from embedder import ImageEmbedder, DEVICE
# Defaults for where the embeddings dictionary is written; both can be
# overridden with --output-dir / --output-file.
DEFAULT_EMB_FILE = "image_embeddings.pkl"
DEFAULT_OUTPUT_DIR = Path("embeddings")

# Root logging at INFO so progress messages are emitted, plus a logger named
# after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _positive_int(value):
    """argparse ``type`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number


def parse_args(argv=None):
    """Parse the command-line options of the batch embedding script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The populated ``argparse.Namespace`` with ``image_dir``,
        ``output_dir``, ``output_file``, ``vision_model``, ``embedding_dim``,
        ``llava_ckpt``, ``batch_size`` and ``resume``.

    Invalid values (including a ``--batch-size`` below 1) make argparse print
    a usage error and exit.
    """
    parser = argparse.ArgumentParser(
        description="Batch-convert images into 4096D embeddings stored in unified dictionary"
    )
    parser.add_argument(
        "--image-dir", required=True,
        help="Path to folder containing images (jpg/jpeg/png/bmp/gif/webp/tiff)"
    )
    parser.add_argument(
        "--output-dir", default=str(DEFAULT_OUTPUT_DIR),
        help="Directory to save embeddings file"
    )
    parser.add_argument(
        "--output-file", default=DEFAULT_EMB_FILE,
        help="Filename for embeddings dictionary (pkl format)"
    )
    parser.add_argument(
        "--vision-model", default="openai/clip-vit-large-patch14-336",
        help="Hugging Face CLIP vision model name"
    )
    parser.add_argument(
        "--embedding-dim", type=int, default=4096,
        help="Output embedding dimension"
    )
    parser.add_argument(
        "--llava-ckpt", default=None,
        help="(Optional) LLaVA checkpoint to load projection weights from"
    )
    # The batch size is used as a range() step when batching, so zero or a
    # negative value is rejected here rather than failing later.
    parser.add_argument(
        "--batch-size", type=_positive_int, default=32,
        help="Batch size for processing images"
    )
    parser.add_argument(
        "--resume", action="store_true",
        help="Resume from existing embeddings file (skip already processed images)"
    )
    return parser.parse_args(argv)
def find_images(folder: str, *, recursive: bool = False) -> list:
    """Return the sorted absolute paths of supported image files in ``folder``.

    Args:
        folder: Directory to scan.
        recursive: When True, also descend into subdirectories. The default
            (False) looks only at the directory's immediate entries.

    Returns:
        A sorted list of resolved path strings, without duplicates. Extensions
        are matched case-insensitively.

    Raises:
        OSError: If ``folder`` cannot be listed (for example it does not
            exist or is not a directory).
    """
    supported_exts = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".tif"}
    folder_path = Path(folder)
    candidates = folder_path.rglob("*") if recursive else folder_path.iterdir()
    # A set is used because distinct entries (e.g. symlinks) can resolve to
    # the same file; the cheap suffix test runs before the is_file() stat.
    found_files = {
        str(file_path.resolve())
        for file_path in candidates
        if file_path.suffix.lower() in supported_exts and file_path.is_file()
    }
    return sorted(found_files)
def load_existing_embeddings(filepath: Path) -> Dict[str, np.ndarray]:
    """Load a previously saved embeddings dictionary from ``filepath``.

    Args:
        filepath: Location of the pickle file.

    Returns:
        The unpickled dictionary, or an empty dict when the file does not
        exist, cannot be read or unpickled, or does not contain a dict.
        Failures are logged as warnings rather than raised.
    """
    if not filepath.exists():
        return {}
    # SECURITY: pickle.load can execute arbitrary code from the file. Only use
    # this on embeddings files produced by a trusted run of this tool.
    try:
        with open(filepath, "rb") as f:
            embeddings = pickle.load(f)
    except Exception as e:
        # Unpickling can raise many unrelated exception types, so this stays a
        # broad best-effort boundary: warn and start from an empty dictionary.
        logger.warning(f"Failed to load existing embeddings: {e}")
        return {}
    if not isinstance(embeddings, dict):
        # Callers index and assign by image path; anything but a dict would
        # only fail later, so reject it here.
        logger.warning(
            f"Failed to load existing embeddings: expected a dict in {filepath}, "
            f"got {type(embeddings).__name__}"
        )
        return {}
    logger.info(f"Loaded {len(embeddings)} existing embeddings from {filepath}")
    return embeddings
def save_embeddings(embeddings: Dict[str, np.ndarray], filepath: Path):
    """Pickle ``embeddings`` to ``filepath`` via ``atomic_write``.

    The parent directory is created when missing, and an existing file at
    ``filepath`` is overwritten.
    """
    target_dir = filepath.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    with atomic_write(filepath, mode='wb', overwrite=True) as sink:
        pickle.dump(embeddings, sink, protocol=pickle.HIGHEST_PROTOCOL)

    entry_count = len(embeddings)
    logger.info(f"Saved {entry_count} embeddings to {filepath}")
# Number of batches between checkpoint saves of the embeddings dictionary.
_CHECKPOINT_EVERY_N_BATCHES = 10


def _load_rgb_images(paths):
    """Open each path as an RGB image, skipping files that cannot be read.

    Returns an ``(images, loaded_paths)`` pair of equal length. Paths that
    fail to open or convert are logged as warnings and left out of both lists.
    """
    images = []
    loaded_paths = []
    for img_path in paths:
        try:
            # The context manager closes the underlying file once the
            # converted copy has been produced.
            with Image.open(img_path) as source:
                rgb_image = source.convert("RGB")
        except Exception as e:
            logger.warning(f"Failed to load {img_path}: {e}")
            continue
        images.append(rgb_image)
        loaded_paths.append(img_path)
    return images, loaded_paths


def _embed_images(embedder, images):
    """Return an array holding one embedding row per image in ``images``.

    Raises:
        ValueError: If the embedder returns a different number of rows than
            images, which would otherwise mis-pair paths and embeddings.
    """
    with torch.no_grad():
        if len(images) == 1:
            # A lone image is passed unwrapped and its result reshaped to
            # (1, D) so the caller can always iterate over rows.
            vectors = embedder.image_to_embedding(images[0]).cpu().numpy()
            vectors = vectors.reshape(1, -1)
        else:
            vectors = embedder.image_to_embedding(images).cpu().numpy()
    if len(vectors) != len(images):
        raise ValueError(
            f"Embedder returned {len(vectors)} embeddings for {len(images)} images"
        )
    return vectors


def main():
    """Embed every supported image in ``--image-dir`` and save the results.

    Steps: parse the CLI arguments, discover images, optionally resume from an
    existing embeddings file, embed the remaining images batch by batch with
    periodic checkpoint saves, then write the final dictionary and log a
    summary.

    Raises:
        RuntimeError: If no supported images are found in the image directory.
        ValueError: If the batch size is smaller than 1.
    """
    args = parse_args()
    if args.batch_size < 1:
        # The batch size is the range() step below; fail with a clear message.
        raise ValueError("--batch-size must be a positive integer")

    # Setup paths
    output_file = Path(args.output_dir) / args.output_file

    # Discover images
    logger.info(f"Scanning for images in {args.image_dir}")
    image_paths = find_images(args.image_dir)
    if not image_paths:
        raise RuntimeError(f"No supported images found in {args.image_dir}")
    logger.info(f"Found {len(image_paths)} images")

    # Load existing embeddings if resuming, and skip what is already done
    embeddings_dict = {}
    if args.resume:
        embeddings_dict = load_existing_embeddings(output_file)
        new_image_paths = [p for p in image_paths if p not in embeddings_dict]
        logger.info(f"Resuming: {len(embeddings_dict)} already processed, {len(new_image_paths)} remaining")
        image_paths = new_image_paths
        if not image_paths:
            logger.info("All images already processed!")
            return

    # Initialize embedder
    logger.info(f"Initializing ImageEmbedder on {DEVICE}")
    embedder = ImageEmbedder(
        vision_model_name=args.vision_model,
        proj_out_dim=args.embedding_dim,
        llava_ckpt_path=args.llava_ckpt
    )
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Process images in batches
    processed_count = 0
    failed_count = 0
    progress_bar = tqdm(total=len(image_paths), desc="Processing images")
    try:
        batch_starts = range(0, len(image_paths), args.batch_size)
        for batch_index, batch_start in enumerate(batch_starts):
            batch_paths = image_paths[batch_start:batch_start + args.batch_size]

            batch_images, valid_paths = _load_rgb_images(batch_paths)
            load_failures = len(batch_paths) - len(valid_paths)
            if load_failures:
                failed_count += load_failures
                progress_bar.update(load_failures)
            if not batch_images:
                continue

            try:
                embeddings = _embed_images(embedder, batch_images)
            except Exception as e:
                logger.error(f"Failed to process batch starting at {batch_start}: {e}")
                # Count only the images that loaded: unreadable ones in this
                # batch were already counted as failures above.
                failed_count += len(valid_paths)
                progress_bar.update(len(valid_paths))
                continue

            for img_path, embedding in zip(valid_paths, embeddings):
                embeddings_dict[img_path] = embedding
            processed_count += len(valid_paths)
            progress_bar.update(len(valid_paths))

            # Periodic checkpoint. A failed checkpoint must not mark the
            # already-stored batch as failed; the final save is still tried.
            if batch_index % _CHECKPOINT_EVERY_N_BATCHES == 0:
                try:
                    save_embeddings(embeddings_dict, output_file)
                except Exception as e:
                    logger.error(f"Checkpoint save failed after batch starting at {batch_start}: {e}")
    finally:
        progress_bar.close()

    # Final save
    save_embeddings(embeddings_dict, output_file)

    # Summary
    total_processed = len(embeddings_dict)
    logger.info("\nProcessing complete!")
    logger.info(f"Total embeddings: {total_processed}")
    logger.info(f"Newly processed: {processed_count}")
    logger.info(f"Failed: {failed_count}")
    logger.info(f"Embedding dimension: {args.embedding_dim}")
    logger.info(f"Output file: {output_file}")

    # Report shape and norm of one stored embedding as a sanity check
    if embeddings_dict:
        sample_path = next(iter(embeddings_dict))
        sample_emb = embeddings_dict[sample_path]
        logger.info(f"Sample embedding shape: {sample_emb.shape}")
        logger.info(f"Sample embedding norm: {np.linalg.norm(sample_emb):.4f}")

    # Final memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"Final GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
# Run the batch job only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()