"""CLIP vision encoder plus a linear projection, optionally loaded from a LLaVA checkpoint.

Importing this module pins the process to GPU 0 via CUDA_VISIBLE_DEVICES and
raises RuntimeError if CUDA is unavailable (there is no CPU fallback).
"""
import os

# Force use of GPU 0 (or change "0" to whichever GPU index you want).
# Set before importing torch: CUDA_VISIBLE_DEVICES is only honoured if it is in
# place before CUDA is initialised, and setting it after the import silently
# stops working as soon as anything touches CUDA first.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from typing import Optional  # noqa: E402

import torch  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import CLIPImageProcessor, CLIPVisionModel  # noqa: E402

print(torch.version.cuda)

# Now select device — no fallback to CPU.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available—cannot run on GPU")
DEVICE = torch.device("cuda")
print(f"→ Running exclusively on {torch.cuda.get_device_name(0)}")


class ImageEmbedder:
    """Embed a PIL image with a frozen CLIP vision tower followed by a linear layer.

    The linear projection maps the vision model's hidden size to ``proj_out_dim``.
    It is randomly initialised unless ``llava_ckpt_path`` is given, in which case
    its weight and bias are read from the checkpoint keys
    ``model.mm_projector.weight`` and ``model.mm_projector.bias``.
    """

    # Key prefix under which the projector tensors are stored in the checkpoint.
    _PROJECTOR_PREFIX = "model.mm_projector."

    def __init__(
        self,
        vision_model_name: str = "openai/clip-vit-large-patch14-336",
        proj_out_dim: int = 5120,
        llava_ckpt_path: Optional[str] = None,
    ):
        """Load the CLIP processor/encoder and build (or load) the projection.

        Args:
            vision_model_name: Hugging Face model id or path for CLIP.
            proj_out_dim: Output size of the projection layer.
            llava_ckpt_path: Optional path to a checkpoint file that contains
                the projector weights; when None the projection stays random.

        Raises:
            KeyError: If the checkpoint lacks the projector weight or bias.
        """
        # Load CLIP vision encoder + processor.
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel.from_pretrained(vision_model_name).to(DEVICE).eval()
        )
        # Freeze the encoder: it is only used for inference here.
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # Projection (vision hidden size -> proj_out_dim) on DEVICE.
        vision_hidden_dim = self.vision_model.config.hidden_size  # should be 1024 for ViT-L/14
        self.projection = torch.nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()

        if llava_ckpt_path is not None:
            # load_state_dict copies the CPU tensors into the on-device parameters.
            self.projection.load_state_dict(self._load_projector_state(llava_ckpt_path))

    @staticmethod
    def _load_projector_state(ckpt_path: str) -> dict:
        """Return ``{"weight": ..., "bias": ...}`` for the projector from a full checkpoint."""
        # Load on CPU: the full checkpoint may be far larger than free GPU memory,
        # and only two small tensors are actually needed.
        # NOTE(review): torch.load unpickles the file — only use trusted checkpoints
        # (pass weights_only=True if the installed torch version supports it).
        ckpt = torch.load(ckpt_path, map_location="cpu")

        prefix = ImageEmbedder._PROJECTOR_PREFIX
        state = {}
        for name in ("weight", "bias"):
            key = prefix + name
            if key not in ckpt:
                raise KeyError(
                    f"{key!r} not found in checkpoint {ckpt_path!r}; "
                    "expected a single-linear-layer projector"
                )
            state[name] = ckpt[key]
        return state

    def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
        """Embed a single image.

        Args:
            image: PIL image; resizing/normalisation is done by the CLIP processor.

        Returns:
            A 1-D tensor of size ``proj_out_dim`` on ``DEVICE``.
        """
        # Preprocess & move to device.
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (1, 3, H, W)

        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            # pooler_output: one pooled vector per image, (1, hidden_size).
            # NOTE(review): LLaVA itself feeds its projector per-patch features from
            # the penultimate encoder layer rather than pooler_output — confirm that
            # a pooled single vector is what downstream code expects.
            feat = out.pooler_output
            emb = self.projection(feat)  # (1, proj_out_dim)

        return emb.squeeze(0)  # -> (proj_out_dim,)