56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
import torch
|
||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||
from PIL import Image
|
||
|
||
# Default compute device: the CUDA GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
|
||
|
||
|
||
class ImageEmbedder:
    """Embed an image with a frozen CLIP vision encoder followed by a linear projection.

    The CLIP ``pooler_output`` feature is passed through a
    ``torch.nn.Linear(hidden_size, proj_out_dim)`` layer. If a LLaVA checkpoint
    path is supplied, that layer's weight and bias are loaded from its
    ``model.mm_projector.*`` entries. Otherwise the layer keeps PyTorch's
    default random initialization.
    """

    def __init__(self,
                 vision_model_name: str = "openai/clip-vit-large-patch14-336",
                 proj_out_dim: int = 5120,
                 llava_ckpt_path: "str | None" = None):
        """Load the CLIP processor/encoder and build (optionally load) the projection.

        Args:
            vision_model_name: Hugging Face id or local path used for both the
                CLIP image processor and the CLIP vision model.
            proj_out_dim: Output width of the projection layer. The 5120
                default presumably matches the target LLM's hidden size
                (TODO confirm).
            llava_ckpt_path: Optional path to a ``torch.save``-d state dict
                containing ``model.mm_projector.weight`` and
                ``model.mm_projector.bias``.

        Raises:
            KeyError: If the checkpoint lacks either projector entry.
            RuntimeError: If the projector tensor shapes do not match the
                projection layer (raised by ``load_state_dict``).
        """
        # CLIP preprocessing (resize / crop / normalize) and the vision tower.
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name)
            .to(DEVICE)
            .eval()
        )

        # Freeze the vision encoder; it is only used for feature extraction.
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # Projection from the encoder width (1024 for the default ViT-L/14-336
        # model) to proj_out_dim, placed on the same device as the encoder.
        vision_hidden_dim = self.vision_model.config.hidden_size
        self.projection = torch.nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()

        if llava_ckpt_path is not None:
            self._load_projector_weights(llava_ckpt_path)

    def _load_projector_weights(self, ckpt_path: str) -> None:
        """Copy ``model.mm_projector.{weight,bias}`` from *ckpt_path* into ``self.projection``."""
        # Load on CPU: a full LLaVA checkpoint also carries the language-model
        # weights, and mapping all of them onto the GPU just to read two
        # tensors can exhaust device memory. load_state_dict() below copies
        # the tensors onto the projection layer's device.
        # weights_only=True (torch >= 1.13) stops torch.load from executing
        # arbitrary pickled code from the file.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)

        prefix = "model.mm_projector."
        keys = (prefix + "weight", prefix + "bias")
        missing = [k for k in keys if k not in ckpt]
        if missing:
            # NOTE(review): multi-layer (MLP) projectors presumably store keys
            # like 'model.mm_projector.0.weight' and would land here; a sharded
            # checkpoint may also hold the projector in a different shard.
            raise KeyError(
                f"checkpoint {ckpt_path!r} is missing projector entries "
                f"{missing}; expected single-Linear projector keys {list(keys)}"
            )

        # Strip the prefix so the names match nn.Linear's 'weight' / 'bias'.
        state = {k[len(prefix):]: ckpt[k] for k in keys}
        del ckpt  # release the rest of the (possibly very large) checkpoint
        self.projection.load_state_dict(state)

    def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
        """Return the projected CLIP embedding of a single image.

        Args:
            image: PIL image; preprocessing is delegated to ``self.processor``.

        Returns:
            A 1-D tensor of length ``proj_out_dim`` on ``DEVICE``, computed
            without autograd tracking.
        """
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (1, 3, H, W)

        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            # NOTE(review): in transformers, pooler_output is the
            # post-layernorm CLS-token feature. LLaVA appears to feed its
            # projector penultimate-layer patch features instead, so LLaVA
            # projector weights may see out-of-distribution input here —
            # confirm this pooled usage is intended.
            feat = out.pooler_output  # (1, hidden_size)
            emb = self.projection(feat)  # (1, proj_out_dim)

        return emb.squeeze(0)  # (proj_out_dim,)
|