import torch
from transformers import AutoTokenizer, AutoModel
from typing import List

# Use the same DEVICE as the image embedder
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TextEmbedder:
    """
    Encodes text into the same embedding space (assumes your LLM was
    aligned with LLaVA's projector during fine-tuning).
    """

    def __init__(self, model_name: str, tokenizer_name: str = None):
        tokenizer_name = tokenizer_name or model_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.text_model = (
            AutoModel
            .from_pretrained(model_name)
            .to(DEVICE)
            .eval()
        )
        # Freeze the text encoder; it is used for inference only.
        for p in self.text_model.parameters():
            p.requires_grad = False

    def text_to_embedding(self, texts: List[str]) -> torch.Tensor:
        """
        Returns a tensor of shape (batch_size, hidden_dim) on DEVICE.
        Uses pooler_output if available; otherwise mean-pools tokens.
        """
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )
        # Move all inputs to GPU if available
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.text_model(**inputs)
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            emb = outputs.pooler_output  # (batch, hidden_dim)
        else:
            # Mean-pool over real tokens only, excluding padding positions.
            mask = inputs["attention_mask"].unsqueeze(-1)  # (batch, seq_len, 1)
            summed = (outputs.last_hidden_state * mask).sum(dim=1)
            emb = summed / mask.sum(dim=1).clamp(min=1)
        return emb
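

# Minimal usage sketch. The model name below is illustrative only; substitute
# whichever text encoder was actually aligned with your projector during
# fine-tuning. Any Hugging Face encoder loadable via AutoModel should work.
if __name__ == "__main__":
    embedder = TextEmbedder("sentence-transformers/all-MiniLM-L6-v2")
    embeddings = embedder.text_to_embedding(
        ["a photo of a cat", "a photo of a dog"]
    )
    print(embeddings.shape)  # (2, hidden_dim), e.g. (2, 384) for MiniLM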