import torch
import torch.nn as nn
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use CUDA if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageEmbedder:
    """
    Image embedder that outputs 4096-dimensional embeddings from a CLIP vision model.

    Stays compatible with existing LLaVA projector weights while reducing the
    output dimension.
    """

    def __init__(self,
                 vision_model_name: str = "openai/clip-vit-large-patch14-336",
                 proj_out_dim: int = 4096,  # reduced from LLaVA's 5120 to 4096
                 llava_ckpt_path: str = None):
        logger.info(f"Initializing ImageEmbedder with {vision_model_name}")

        # Load the CLIP vision encoder and its image processor
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name)
            .to(DEVICE)
            .eval()
        )

        # Freeze vision model parameters
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # Vision model's hidden dimension (1024 for CLIP ViT-L/14-336)
        vision_hidden_dim = self.vision_model.config.hidden_size

        # Simple single-layer linear projection; only the output dimension
        # differs from LLaVA's 5120
        self.projection = nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()

        # Load LLaVA's projector weights if a checkpoint is provided
        if llava_ckpt_path is not None:
            self._load_llava_weights(llava_ckpt_path, vision_hidden_dim, proj_out_dim)

        logger.info(f"ImageEmbedder ready: {vision_hidden_dim}D -> {proj_out_dim}D on {DEVICE}")

    def _load_llava_weights(self, ckpt_path, vision_dim, output_dim):
        """Load LLaVA projector weights, adapting them to a different output dimension."""
        try:
            ckpt = torch.load(ckpt_path, map_location=DEVICE)

            # Extract the multimodal projector weights
            keys = ["model.mm_projector.weight", "model.mm_projector.bias"]
            if all(k in ckpt for k in keys):
                original_weight = ckpt["model.mm_projector.weight"]  # shape: (5120, 1024)
                original_bias = ckpt["model.mm_projector.bias"]      # shape: (5120,)

                if output_dim == 5120:
                    # Load the original weights directly
                    state = {"weight": original_weight, "bias": original_bias}
                    self.projection.load_state_dict(state)
                    logger.info("Loaded original LLaVA projection weights")
                elif output_dim < 5120:
                    # Truncate to the first output_dim rows (a simple heuristic;
                    # it does not guarantee the most informative features are kept)
                    truncated_weight = original_weight[:output_dim, :]  # (output_dim, 1024)
                    truncated_bias = original_bias[:output_dim]         # (output_dim,)
                    state = {"weight": truncated_weight, "bias": truncated_bias}
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded truncated LLaVA weights: 5120D -> {output_dim}D")
                else:
                    # For larger dimensions, copy the original rows and
                    # Xavier-initialize the remaining ones
                    new_weight = torch.zeros(output_dim, vision_dim, device=DEVICE)
                    new_bias = torch.zeros(output_dim, device=DEVICE)
                    new_weight[:5120, :] = original_weight
                    new_bias[:5120] = original_bias
                    nn.init.xavier_uniform_(new_weight[5120:, :])
                    nn.init.zeros_(new_bias[5120:])
                    state = {"weight": new_weight, "bias": new_bias}
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded extended LLaVA weights: 5120D -> {output_dim}D")
            else:
                logger.warning("LLaVA projector weights not found, using random initialization")
        except Exception as e:
            logger.warning(f"Failed to load LLaVA weights: {e}, using random initialization")

    def image_to_embedding(self, image_input) -> torch.Tensor:
        """
        Convert image(s) to embeddings.

        Args:
            image_input: PIL.Image or list of PIL.Images

        Returns:
            torch.Tensor: embeddings of shape (proj_out_dim,) for a single image
                or (B, proj_out_dim) for a batch of images
        """
        # Handle single image vs. batch input
        if isinstance(image_input, Image.Image):
            images = [image_input]
            single_image = True
        else:
            images = image_input
            single_image = False

        # Preprocess images
        inputs = self.processor(images=images, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (B, 3, H, W)

        # Forward pass through the frozen vision encoder and the projection
        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            feat = out.pooler_output     # (B, 1024)
            emb = self.projection(feat)  # (B, proj_out_dim)

        # Return the appropriate shape
        if single_image:
            return emb.squeeze(0)  # (proj_out_dim,)
        else:
            return emb             # (B, proj_out_dim)

    def get_memory_usage(self):
        """Get current GPU memory usage if available."""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,    # GB
            }
        return {'allocated': 0, 'reserved': 0}


# Test function
def main():
    """Simple test of the embedder."""
    import argparse

    parser = argparse.ArgumentParser(description="Test ImageEmbedder")
    parser.add_argument("--image", type=str, help="Path to test image")
    parser.add_argument("--model", type=str, default="openai/clip-vit-large-patch14-336")
    parser.add_argument("--dim", type=int, default=4096)
    parser.add_argument("--llava-ckpt", type=str, help="Path to LLaVA checkpoint")
    args = parser.parse_args()

    embedder = ImageEmbedder(
        vision_model_name=args.model,
        proj_out_dim=args.dim,
        llava_ckpt_path=args.llava_ckpt
    )

    if args.image:
        img = Image.open(args.image).convert("RGB")
        emb = embedder.image_to_embedding(img)
        print(f"Embedding shape: {emb.shape}")
        print(f"Sample values: {emb[:10]}")

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")


if __name__ == "__main__":
    main()
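
# Usage sketch (comments only, not executed): a minimal example of batch embedding,
# assuming this module is importable as `image_embedder` and that a local LLaVA-13B
# checkpoint exposing "model.mm_projector.weight"/"model.mm_projector.bias" exists at
# the hypothetical path below. Shapes follow image_to_embedding()'s contract.
#
#   from image_embedder import ImageEmbedder
#   from PIL import Image
#
#   embedder = ImageEmbedder(llava_ckpt_path="checkpoints/llava-13b/pytorch_model.bin")
#   images = [Image.open(p).convert("RGB") for p in ("cat.jpg", "dog.jpg")]
#   batch = embedder.image_to_embedding(images)  # tensor of shape (2, 4096)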