"""Batch-convert a folder of images into embeddings stored in one pickled dict.

The output file maps each image's resolved absolute path (str) to the
embedding array returned by ``ImageEmbedder.image_to_embedding`` for it.
"""

import os
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict

import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from atomicwrites import atomic_write

from embedder import ImageEmbedder, DEVICE

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_OUTPUT_DIR = Path("embeddings")
DEFAULT_EMB_FILE = "image_embeddings.pkl"


def parse_args():
    """Parse and validate command-line arguments.

    Returns:
        The populated ``argparse.Namespace``.

    Exits (via ``parser.error``) when ``--batch-size`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Batch-convert images into 4096D embeddings stored in unified dictionary"
    )
    parser.add_argument(
        "--image-dir",
        required=True,
        help="Path to folder containing images (jpg/png/bmp/gif/webp)",
    )
    parser.add_argument(
        "--output-dir",
        default=str(DEFAULT_OUTPUT_DIR),
        help="Directory to save embeddings file",
    )
    parser.add_argument(
        "--output-file",
        default=DEFAULT_EMB_FILE,
        help="Filename for embeddings dictionary (pkl format)",
    )
    parser.add_argument(
        "--vision-model",
        default="openai/clip-vit-large-patch14-336",
        help="Hugging Face CLIP vision model name",
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=4096,
        help="Output embedding dimension",
    )
    parser.add_argument(
        "--llava-ckpt",
        default=None,
        help="(Optional) LLaVA checkpoint to load projection weights from",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for processing images",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from existing embeddings file (skip already processed images)",
    )
    args = parser.parse_args()

    # A zero step makes range() raise only after the model has been loaded, and
    # a negative step silently skips every image and then rewrites the output
    # file, so reject both up front.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    return args


def find_images(folder: str) -> list:
    """Return the sorted, de-duplicated image paths directly inside *folder*.

    The scan is non-recursive. Extensions are matched case-insensitively and
    every hit is returned as a resolved absolute path string; because paths are
    resolved, two directory entries that point at the same file collapse into
    one result.
    """
    supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".tif")
    found_files = {
        str(entry.resolve())
        for entry in Path(folder).iterdir()
        if entry.is_file() and entry.suffix.lower() in supported_exts
    }
    return sorted(found_files)


def load_existing_embeddings(filepath: Path) -> Dict[str, np.ndarray]:
    """Load a previously saved embeddings dictionary.

    Returns an empty dict when the file is missing, cannot be unpickled, or
    does not contain a dict; the last two cases are logged as warnings.
    """
    if not filepath.exists():
        return {}

    try:
        # SECURITY: pickle.load can execute arbitrary code while deserializing.
        # Only point this at embeddings files produced by this script.
        with open(filepath, "rb") as f:
            embeddings = pickle.load(f)
    except Exception as e:
        logger.warning(f"Failed to load existing embeddings: {e}")
        return {}

    # Callers index the result by path, so anything but a dict is unusable.
    if not isinstance(embeddings, dict):
        logger.warning(
            f"Ignoring {filepath}: expected a dict, got {type(embeddings).__name__}"
        )
        return {}

    logger.info(f"Loaded {len(embeddings)} existing embeddings from {filepath}")
    return embeddings


def save_embeddings(embeddings: Dict[str, np.ndarray], filepath: Path):
    """Pickle *embeddings* to *filepath*, creating parent directories as needed.

    The write goes through ``atomic_write`` with ``overwrite=True``, replacing
    any existing file at *filepath*.
    """
    filepath.parent.mkdir(parents=True, exist_ok=True)

    with atomic_write(filepath, mode='wb', overwrite=True) as f:
        pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)

    logger.info(f"Saved {len(embeddings)} embeddings to {filepath}")


def _load_rgb_images(paths):
    """Decode *paths* to RGB images; return ``(images, loaded_paths)`` for the successes."""
    images = []
    loaded_paths = []
    for img_path in paths:
        try:
            # Image.open is lazy and holds the file open; convert() decodes into
            # a new in-memory image, so the handle can be released immediately.
            with Image.open(img_path) as raw:
                rgb = raw.convert("RGB")
        except Exception as e:
            logger.warning(f"Failed to load {img_path}: {e}")
            continue
        images.append(rgb)
        loaded_paths.append(img_path)
    return images, loaded_paths


def main():
    """Embed every image found in ``--image-dir`` and save the resulting dict.

    With ``--resume``, paths already present in the existing output file are
    skipped. Images that fail to load and batches that fail to embed are logged
    and counted, not fatal.

    Raises:
        RuntimeError: if the image directory contains no supported images.
    """
    args = parse_args()

    # Setup paths
    output_dir = Path(args.output_dir)
    output_file = output_dir / args.output_file

    # Discover images
    logger.info(f"Scanning for images in {args.image_dir}")
    image_paths = find_images(args.image_dir)

    if not image_paths:
        raise RuntimeError(f"No supported images found in {args.image_dir}")

    logger.info(f"Found {len(image_paths)} images")

    # Load existing embeddings if resuming
    embeddings_dict = {}
    if args.resume:
        embeddings_dict = load_existing_embeddings(output_file)
        # Filter out already processed images
        new_image_paths = [p for p in image_paths if p not in embeddings_dict]
        logger.info(f"Resuming: {len(embeddings_dict)} already processed, {len(new_image_paths)} remaining")
        image_paths = new_image_paths

    if not image_paths:
        logger.info("All images already processed!")
        return

    # Initialize embedder
    logger.info(f"Initializing ImageEmbedder on {DEVICE}")
    embedder = ImageEmbedder(
        vision_model_name=args.vision_model,
        proj_out_dim=args.embedding_dim,
        llava_ckpt_path=args.llava_ckpt,
    )

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Process images in batches
    processed_count = 0
    failed_count = 0

    # The context manager closes the bar even if something unexpected raises.
    with tqdm(total=len(image_paths), desc="Processing images") as progress_bar:
        for batch_start in range(0, len(image_paths), args.batch_size):
            batch_paths = image_paths[batch_start:batch_start + args.batch_size]

            batch_images, valid_paths = _load_rgb_images(batch_paths)
            load_failures = len(batch_paths) - len(valid_paths)
            if load_failures:
                failed_count += load_failures
                progress_bar.update(load_failures)

            if not batch_images:
                continue

            # Generate embeddings for batch
            try:
                with torch.no_grad():
                    if len(batch_images) == 1:
                        # A lone image is passed unwrapped; its result is
                        # reshaped so both branches yield a 2-D (N, D) array.
                        embeddings = embedder.image_to_embedding(batch_images[0]).cpu().numpy()
                        embeddings = embeddings.reshape(1, -1)
                    else:
                        embeddings = embedder.image_to_embedding(batch_images).cpu().numpy()
            except Exception as e:
                logger.error(f"Failed to process batch starting at {batch_start}: {e}")
                # Only the images that loaded are newly failed here; load
                # failures were already counted above.
                failed_count += len(valid_paths)
                progress_bar.update(len(valid_paths))
                continue

            # Store embeddings in dictionary
            for img_path, embedding in zip(valid_paths, embeddings):
                embeddings_dict[img_path] = embedding
                processed_count += 1

            progress_bar.update(len(valid_paths))

            # Periodic save every 10 batches (after batch 1, 11, 21, ...), so
            # an unwritable output location shows up after the first batch.
            if (batch_start // args.batch_size) % 10 == 0:
                try:
                    save_embeddings(embeddings_dict, output_file)
                except Exception as e:
                    # The embeddings are still in memory; the final save below
                    # gets another chance, so a failed checkpoint is not fatal.
                    logger.warning(f"Periodic save to {output_file} failed: {e}")

    # Final save
    save_embeddings(embeddings_dict, output_file)

    # Summary
    total_processed = len(embeddings_dict)
    logger.info("\nProcessing complete!")
    logger.info(f"Total embeddings: {total_processed}")
    logger.info(f"Newly processed: {processed_count}")
    logger.info(f"Failed: {failed_count}")
    logger.info(f"Embedding dimension: {args.embedding_dim}")
    logger.info(f"Output file: {output_file}")

    # Verify a sample embedding
    if embeddings_dict:
        sample_path = next(iter(embeddings_dict))
        sample_emb = embeddings_dict[sample_path]
        logger.info(f"Sample embedding shape: {sample_emb.shape}")
        logger.info(f"Sample embedding norm: {np.linalg.norm(sample_emb):.4f}")

    # Final memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"Final GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")


if __name__ == "__main__":
    main()