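"""Batch-convert a folder of images into fixed-size embeddings.

Each image is encoded by an ImageEmbedder (a CLIP vision model, optionally
with projection weights loaded from a LLaVA checkpoint) and the vectors are
stored in one pickled dictionary mapping resolved image paths to NumPy
arrays.

Example invocation (script and directory names are illustrative):

    python embed_images.py --image-dir ./photos --batch-size 32 --resume
"""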
import argparse
import logging
import pickle
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from atomicwrites import atomic_write
from PIL import Image
from tqdm import tqdm

from embedder import ImageEmbedder, DEVICE
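
# Interface assumed of the local `embedder` module, inferred from its use in
# this script: ImageEmbedder(vision_model_name, proj_out_dim, llava_ckpt_path)
# exposes image_to_embedding(), which maps one PIL image to a (D,) tensor or a
# list of images to a (B, D) tensor, and get_memory_usage(), which returns a
# dict with "allocated"/"reserved" GPU memory in GB.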

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_OUTPUT_DIR = Path("embeddings")
DEFAULT_EMB_FILE = "image_embeddings.pkl"


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Batch-convert images into embeddings (default 4096-D) stored in a unified dictionary"
    )
    parser.add_argument(
        "--image-dir", required=True,
        help="Path to folder containing images (jpg/png/bmp/gif/webp)"
    )
    parser.add_argument(
        "--output-dir", default=str(DEFAULT_OUTPUT_DIR),
        help="Directory to save the embeddings file"
    )
    parser.add_argument(
        "--output-file", default=DEFAULT_EMB_FILE,
        help="Filename for the embeddings dictionary (pkl format)"
    )
    parser.add_argument(
        "--vision-model", default="openai/clip-vit-large-patch14-336",
        help="Hugging Face CLIP vision model name"
    )
    parser.add_argument(
        "--embedding-dim", type=int, default=4096,
        help="Output embedding dimension"
    )
    parser.add_argument(
        "--llava-ckpt", default=None,
        help="(Optional) LLaVA checkpoint to load projection weights from"
    )
    parser.add_argument(
        "--batch-size", type=int, default=32,
        help="Batch size for processing images"
    )
    parser.add_argument(
        "--resume", action="store_true",
        help="Resume from an existing embeddings file (skip already processed images)"
    )
    return parser.parse_args()


def find_images(folder: str) -> List[str]:
    """Find all supported image files in a folder (non-recursive)."""
    supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".tif")

    folder_path = Path(folder)

    # Use a set to avoid duplicates
    found_files = set()

    # Collect files with supported extensions (case-insensitive)
    for file_path in folder_path.iterdir():
        if file_path.is_file():
            if file_path.suffix.lower() in supported_exts:
                found_files.add(str(file_path.resolve()))

    return sorted(found_files)


def load_existing_embeddings(filepath: Path) -> Dict[str, np.ndarray]:
    """Load an existing embeddings dictionary, or return an empty one."""
    if not filepath.exists():
        return {}

    try:
        # Note: pickle.load can execute arbitrary code, so only load
        # embeddings files produced by this script or another trusted source
        with open(filepath, "rb") as f:
            embeddings = pickle.load(f)
        logger.info(f"Loaded {len(embeddings)} existing embeddings from {filepath}")
        return embeddings
    except Exception as e:
        logger.warning(f"Failed to load existing embeddings: {e}")
        return {}


def save_embeddings(embeddings: Dict[str, np.ndarray], filepath: Path):
    """Safely save the embeddings dictionary.

    atomic_write writes to a temporary file and renames it over the target,
    so an interrupted save cannot corrupt a previously written file.
    """
    filepath.parent.mkdir(parents=True, exist_ok=True)

    with atomic_write(filepath, mode="wb", overwrite=True) as f:
        pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)

    logger.info(f"Saved {len(embeddings)} embeddings to {filepath}")


def main():
    """Discover images, embed them in batches, and save the dictionary."""
    args = parse_args()

    # Set up paths
    output_dir = Path(args.output_dir)
    output_file = output_dir / args.output_file

    # Discover images
    logger.info(f"Scanning for images in {args.image_dir}")
    image_paths = find_images(args.image_dir)

    if not image_paths:
        raise RuntimeError(f"No supported images found in {args.image_dir}")

    logger.info(f"Found {len(image_paths)} images")

    # Load existing embeddings if resuming
    embeddings_dict = {}
    if args.resume:
        embeddings_dict = load_existing_embeddings(output_file)

        # Filter out already processed images (keys are resolved paths,
        # matching what find_images returns)
        new_image_paths = [p for p in image_paths if p not in embeddings_dict]
        logger.info(f"Resuming: {len(embeddings_dict)} already processed, {len(new_image_paths)} remaining")
        image_paths = new_image_paths

        if not image_paths:
            logger.info("All images already processed!")
            return

    # Initialize embedder
    logger.info(f"Initializing ImageEmbedder on {DEVICE}")
    embedder = ImageEmbedder(
        vision_model_name=args.vision_model,
        proj_out_dim=args.embedding_dim,
        llava_ckpt_path=args.llava_ckpt,
    )

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Process images in batches
    processed_count = 0
    failed_count = 0

    progress_bar = tqdm(total=len(image_paths), desc="Processing images")

    for batch_start in range(0, len(image_paths), args.batch_size):
        batch_end = min(batch_start + args.batch_size, len(image_paths))
        batch_paths = image_paths[batch_start:batch_end]

        # Load batch images, skipping any that fail to open
        batch_images = []
        valid_paths = []

        for img_path in batch_paths:
            try:
                img = Image.open(img_path).convert("RGB")
                batch_images.append(img)
                valid_paths.append(img_path)
            except Exception as e:
                logger.warning(f"Failed to load {img_path}: {e}")
                failed_count += 1
                progress_bar.update(1)

        if not batch_images:
            continue

        # Generate embeddings for the batch
        try:
            with torch.no_grad():
                if len(batch_images) == 1:
                    embeddings = embedder.image_to_embedding(batch_images[0]).cpu().numpy()
                    embeddings = embeddings.reshape(1, -1)  # Make it (1, D)
                else:
                    embeddings = embedder.image_to_embedding(batch_images).cpu().numpy()  # (B, D)

            # Store embeddings in the dictionary
            for img_path, embedding in zip(valid_paths, embeddings):
                embeddings_dict[img_path] = embedding
                processed_count += 1

            progress_bar.update(len(valid_paths))

            # Periodic save every 10 batches
            if (batch_start // args.batch_size) % 10 == 0:
                save_embeddings(embeddings_dict, output_file)

        except Exception as e:
            logger.error(f"Failed to process batch starting at {batch_start}: {e}")
            # Count only the images that loaded successfully; load failures
            # were already counted and reflected in the progress bar above
            failed_count += len(valid_paths)
            progress_bar.update(len(valid_paths))

    progress_bar.close()

    # Final save
    save_embeddings(embeddings_dict, output_file)

    # Summary
    total_processed = len(embeddings_dict)
    logger.info("Processing complete!")
    logger.info(f"Total embeddings: {total_processed}")
    logger.info(f"Newly processed: {processed_count}")
    logger.info(f"Failed: {failed_count}")
    logger.info(f"Embedding dimension: {args.embedding_dim}")
    logger.info(f"Output file: {output_file}")

    # Verify a sample embedding
    if embeddings_dict:
        sample_path = next(iter(embeddings_dict))
        sample_emb = embeddings_dict[sample_path]
        logger.info(f"Sample embedding shape: {sample_emb.shape}")
        logger.info(f"Sample embedding norm: {np.linalg.norm(sample_emb):.4f}")

    # Final memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"Final GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")


if __name__ == "__main__":
    main()
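
# Downstream usage sketch (illustrative; not executed by this script): the
# output file is a plain pickled dict mapping resolved image paths to NumPy
# vectors of shape (embedding_dim,).
#
#   import pickle
#   with open("embeddings/image_embeddings.pkl", "rb") as f:
#       embs = pickle.load(f)
#   first_path = next(iter(embs))
#   print(embs[first_path].shape)   # e.g. (4096,)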