# Img2Vec/embedder.py
import os

# Force use of GPU 0 (or change "0" to whichever GPU index you want).
# Set this before importing torch so CUDA only ever sees the selected device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image

print(torch.version.cuda)

# Select the device; this embedder runs on GPU only, with no CPU fallback.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available; cannot run on GPU")
DEVICE = torch.device("cuda")
print(f"→ Running exclusively on {torch.cuda.get_device_name(0)}")
class ImageEmbedder:
    def __init__(self,
                 vision_model_name: str = "openai/clip-vit-large-patch14-336",
                 proj_out_dim: int = 5120,
                 llava_ckpt_path: str = None):
        # Load CLIP vision encoder + processor
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name)
            .to(DEVICE)
            .eval()
        )
        # Freeze the vision encoder; it is used purely as a feature extractor
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # Build the projection layer (1024 → proj_out_dim) and move it to DEVICE
        vision_hidden_dim = self.vision_model.config.hidden_size  # 1024 for CLIP ViT-L/14-336
        self.projection = torch.nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()

        # If provided, load LLaVA's projector weights from the full .bin checkpoint
        if llava_ckpt_path is not None:
            ckpt = torch.load(llava_ckpt_path, map_location=DEVICE)
            # Extract only the projector weight + bias
            keys = ["model.mm_projector.weight", "model.mm_projector.bias"]
            state = {
                k.replace("model.mm_projector.", ""): ckpt[k]
                for k in keys
            }
            # Load into our linear layer
            self.projection.load_state_dict(state)
    def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
        # Preprocess and move to device
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (1, 3, H, W)

        # Forward pass through the frozen vision encoder, then project
        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            feat = out.pooler_output        # (1, 1024)
            emb = self.projection(feat)     # (1, proj_out_dim)

        # Return a 1D tensor of size [proj_out_dim] on DEVICE
        return emb.squeeze(0)  # → (proj_out_dim,)
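

# A minimal usage sketch (assumptions: "example.jpg" and the checkpoint path in
# the comment are placeholders, not files shipped with this repo). It builds an
# embedder without projector weights, embeds a single image, and prints the
# resulting vector's shape.
if __name__ == "__main__":
    embedder = ImageEmbedder()  # pass llava_ckpt_path="path/to/llava.bin" to load projector weights
    img = Image.open("example.jpg").convert("RGB")
    vec = embedder.image_to_embedding(img)
    print(vec.shape)  # torch.Size([5120]) with the default proj_out_dim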