Img2Vec/embedder.py

from typing import Optional

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Use CUDA if available, otherwise fall back to CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ImageEmbedder:
    def __init__(self,
                 vision_model_name: str = "openai/clip-vit-large-patch14-336",
                 proj_out_dim: int = 5120,
                 llava_ckpt_path: Optional[str] = None):
        # Load the CLIP vision encoder + its image processor
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name)
            .to(DEVICE)
            .eval()
        )
        # Freeze the vision encoder so its weights are never updated
        for p in self.vision_model.parameters():
            p.requires_grad = False
        # Build the projection layer (1024 → proj_out_dim) and move it to DEVICE
        vision_hidden_dim = self.vision_model.config.hidden_size  # 1024 for ViT-L/14-336
        self.projection = torch.nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()
        # Load LLaVA's projection weights from the full .bin checkpoint
        if llava_ckpt_path is not None:
            ckpt = torch.load(llava_ckpt_path, map_location=DEVICE)
            # Extract only the projector weight + bias (these keys match LLaVA v1's
            # single-linear mm_projector, not v1.5's two-layer MLP projector)
            keys = ["model.mm_projector.weight", "model.mm_projector.bias"]
            state = {
                k.replace("model.mm_projector.", ""): ckpt[k]
                for k in keys
            }
            # Load into our linear layer
            self.projection.load_state_dict(state)

    def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
        # Preprocess & move to device
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (1, 3, H, W)
        # Forward pass through the frozen encoder and projection
        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            feat = out.pooler_output     # (1, 1024)
            emb = self.projection(feat)  # (1, proj_out_dim)
        # Returns a 1D tensor of size [proj_out_dim] on DEVICE
        return emb.squeeze(0)            # → (proj_out_dim,)
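
# A minimal usage sketch, assuming a local image file; the image path below is
# hypothetical, and the printed shape assumes the default proj_out_dim=5120.
if __name__ == "__main__":
    embedder = ImageEmbedder()  # optionally pass llava_ckpt_path="..." to load LLaVA's projector
    img = Image.open("cat.jpg").convert("RGB")  # hypothetical image path
    vec = embedder.image_to_embedding(img)
    print(vec.shape)  # torch.Size([5120]) with the defaults above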