Img2Vec/embedder.py

"""CLIP-based image embedder.

Wraps a frozen CLIP vision encoder and a single linear projection that maps the
pooled 1024-D CLIP feature to a configurable output dimension (4096 by
default). The projection can optionally be initialised from a LLaVA multimodal
projector checkpoint.
"""
import logging

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use CUDA if available, otherwise fall back to CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageEmbedder:
    """
    Image embedder that outputs 4096-dimensional embeddings (by default) from a
    CLIP vision model. The projection can be initialised from existing LLaVA
    projector weights, adapted to the reduced output dimension.
    """

    def __init__(self,
                 vision_model_name: str = "openai/clip-vit-large-patch14-336",
                 proj_out_dim: int = 4096,  # reduced from LLaVA's native 5120
                 llava_ckpt_path: str = None):
        logger.info(f"Initializing ImageEmbedder with {vision_model_name}")

        # Load the CLIP vision encoder and its image processor
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name)
            .to(DEVICE)
            .eval()
        )

        # Freeze the vision model parameters
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # Vision model's hidden dimension (1024 for ViT-L/14-336)
        vision_hidden_dim = self.vision_model.config.hidden_size

        # Single linear projection from the CLIP feature to the output dimension
        self.projection = nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()

        # Optionally initialise the projection from LLaVA's projector weights
        if llava_ckpt_path is not None:
            self._load_llava_weights(llava_ckpt_path, vision_hidden_dim, proj_out_dim)

        logger.info(f"ImageEmbedder ready: {vision_hidden_dim}D -> {proj_out_dim}D on {DEVICE}")

    def _load_llava_weights(self, ckpt_path, vision_dim, output_dim):
        """Load LLaVA projection weights, adapting them to a different output dimension."""
        try:
            ckpt = torch.load(ckpt_path, map_location=DEVICE)

            # Extract the multimodal projector weights from the checkpoint
            keys = ["model.mm_projector.weight", "model.mm_projector.bias"]
            if all(k in ckpt for k in keys):
                original_weight = ckpt["model.mm_projector.weight"]  # shape: (5120, 1024)
                original_bias = ckpt["model.mm_projector.bias"]      # shape: (5120,)

                if output_dim == 5120:
                    # Same dimension as LLaVA: load the weights directly
                    state = {
                        "weight": original_weight,
                        "bias": original_bias,
                    }
                    self.projection.load_state_dict(state)
                    logger.info("Loaded original LLaVA projection weights")
                elif output_dim < 5120:
                    # Smaller dimension: keep the first output_dim rows.
                    # This is a simple heuristic; rows are not ordered by importance.
                    truncated_weight = original_weight[:output_dim, :]  # (output_dim, 1024)
                    truncated_bias = original_bias[:output_dim]         # (output_dim,)
                    state = {
                        "weight": truncated_weight,
                        "bias": truncated_bias,
                    }
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded truncated LLaVA weights: 5120D -> {output_dim}D")
                else:
                    # Larger dimension: copy the original rows and initialise the rest
                    new_weight = torch.zeros(output_dim, vision_dim, device=DEVICE)
                    new_bias = torch.zeros(output_dim, device=DEVICE)

                    # Copy the original weights into the first 5120 rows
                    new_weight[:5120, :] = original_weight
                    new_bias[:5120] = original_bias

                    # Xavier-initialise the extra weight rows; extra biases stay zero
                    nn.init.xavier_uniform_(new_weight[5120:, :])
                    nn.init.zeros_(new_bias[5120:])

                    state = {
                        "weight": new_weight,
                        "bias": new_bias,
                    }
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded extended LLaVA weights: 5120D -> {output_dim}D")
            else:
                logger.warning("LLaVA projector weights not found, using random initialization")
        except Exception as e:
            logger.warning(f"Failed to load LLaVA weights: {e}, using random initialization")

    def image_to_embedding(self, image_input) -> torch.Tensor:
        """
        Convert image(s) to embeddings.

        Args:
            image_input: PIL.Image or list of PIL.Images

        Returns:
            torch.Tensor: Embeddings of shape (proj_out_dim,) for a single image
                          or (B, proj_out_dim) for a batch of images.
        """
        # Accept either a single image or a list of images
        if isinstance(image_input, Image.Image):
            images = [image_input]
            single_image = True
        else:
            images = image_input
            single_image = False

        # Preprocess images into pixel tensors
        inputs = self.processor(images=images, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (B, 3, H, W)

        # Forward pass through the frozen encoder and the projection
        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            feat = out.pooler_output        # (B, 1024)
            emb = self.projection(feat)     # (B, proj_out_dim)

        # Drop the batch dimension for a single image
        if single_image:
            return emb.squeeze(0)  # (proj_out_dim,)
        return emb  # (B, proj_out_dim)

    def get_memory_usage(self):
        """Get current GPU memory usage in GiB, if CUDA is available."""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GiB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,    # GiB
            }
        return {'allocated': 0, 'reserved': 0}
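

# --- Illustrative helper (not part of the original module) -----------------
# A minimal sketch of how two embeddings produced by ImageEmbedder might be
# compared. `embedding_cosine_similarity` is a hypothetical helper added for
# illustration only; swap in whatever similarity metric your pipeline uses.
def embedding_cosine_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor) -> float:
    """Cosine similarity between two 1-D embeddings."""
    return torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=0).item()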


# Test function
def main():
    """Simple CLI test of the embedder."""
    import argparse

    parser = argparse.ArgumentParser(description="Test ImageEmbedder")
    parser.add_argument("--image", type=str, help="Path to test image")
    parser.add_argument("--model", type=str, default="openai/clip-vit-large-patch14-336")
    parser.add_argument("--dim", type=int, default=4096)
    parser.add_argument("--llava-ckpt", type=str, help="Path to LLaVA checkpoint")
    args = parser.parse_args()

    embedder = ImageEmbedder(
        vision_model_name=args.model,
        proj_out_dim=args.dim,
        llava_ckpt_path=args.llava_ckpt,
    )

    if args.image:
        img = Image.open(args.image).convert("RGB")
        emb = embedder.image_to_embedding(img)
        print(f"Embedding shape: {emb.shape}")
        print(f"Sample values: {emb[:10]}")

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GiB, Reserved: {mem['reserved']:.2f}GiB")


if __name__ == "__main__":
    main()
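
# Example invocation (sketch; the paths below are placeholders, not files that
# ship with this repo):
#
#   python embedder.py --image /path/to/photo.jpg --dim 4096 \
#       --llava-ckpt /path/to/llava_checkpoint.bin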