Img2Vec/text_embedder.py
2025-06-13 07:21:58 +08:00

51 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from transformers import AutoTokenizer, AutoModel
from typing import List, Optional
# Compute device chosen once at import time: CUDA when PyTorch reports a
# usable GPU, otherwise the CPU. Per the original note this is meant to match
# the device used by the embedder module (not visible from this file).
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
class TextEmbedder:
    """Encode text into fixed-size embedding vectors with a Hugging Face model.

    The vectors are meant to share an embedding space with image embeddings.
    That only holds if the language model was aligned with LLaVA's projector
    during fine-tuning -- an assumption inherited from the original author
    that this class cannot verify.

    The model is switched to eval mode and its parameters are frozen, so an
    instance is for inference only.
    """

    def __init__(self,
                 model_name: str,
                 tokenizer_name: Optional[str] = None,
                 device: Optional[torch.device] = None):
        """Load the tokenizer and model, move the model to the device, freeze it.

        Args:
            model_name: Identifier or path given to
                ``AutoModel.from_pretrained``.
            tokenizer_name: Identifier or path given to
                ``AutoTokenizer.from_pretrained``. Falls back to
                ``model_name`` when omitted or empty.
            device: Device the model and inputs are placed on. Defaults to
                the module-level ``DEVICE``.
        """
        tokenizer_name = tokenizer_name or model_name
        self.device = DEVICE if device is None else torch.device(device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.text_model = (
            AutoModel
            .from_pretrained(model_name)
            .to(self.device)
            .eval()
        )
        # Inference only: make sure no parameter tracks gradients.
        self.text_model.requires_grad_(False)

    def text_to_embedding(self, texts: List[str]) -> torch.Tensor:
        """Embed a batch of texts.

        Args:
            texts: The strings to encode. They are padded to a common length
                and truncated by the tokenizer.

        Returns:
            A tensor of shape (batch_size, hidden_dim) on ``self.device``.
            The model's ``pooler_output`` is returned when it provides one;
            otherwise the token states are mean-pooled, counting only real
            (non-padding) tokens.
        """
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        # The model lives on self.device, so every input tensor must too.
        inputs = {name: tensor.to(self.device)
                  for name, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.text_model(**inputs)

        pooled = getattr(outputs, "pooler_output", None)
        if pooled is not None:
            return pooled  # (batch, hidden_dim)

        return self._masked_mean(
            outputs.last_hidden_state,
            inputs.get("attention_mask"),
        )

    @staticmethod
    def _masked_mean(hidden: torch.Tensor,
                     attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Average token states over the sequence axis, skipping padding."""
        if attention_mask is None:
            # No mask supplied by the tokenizer: nothing to exclude.
            return hidden.mean(dim=1)

        # (batch, seq) -> (batch, seq, 1) so it broadcasts over hidden_dim.
        weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # A plain mean would let padding vectors dilute short texts and make
        # the result depend on the longest text in the batch.
        token_counts = weights.sum(dim=1).clamp(min=1)  # avoid divide-by-zero
        return (hidden * weights).sum(dim=1) / token_counts