import torch
from transformers import AutoTokenizer, AutoModel
from typing import List, Optional

# Use the same DEVICE as the embedder
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TextEmbedder:
    """
    Encodes text into the same embedding space (assumes your LLM was aligned
    with LLaVA’s projector during fine-tuning).
    """

    def __init__(self,
                 model_name: str,
                 tokenizer_name: Optional[str] = None):
        tokenizer_name = tokenizer_name or model_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.text_model = (
            AutoModel
            .from_pretrained(model_name)
            .to(DEVICE)
            .eval()
        )
        # Freeze the encoder; it is used for inference only.
        for p in self.text_model.parameters():
            p.requires_grad = False

    def text_to_embedding(self, texts: List[str]) -> torch.Tensor:
        """
        Returns a tensor of shape (batch_size, hidden_dim) on DEVICE.
        Uses pooler_output if available; otherwise mean-pools non-padding tokens.
        """
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )
        # Move all inputs to the same device as the model.
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.text_model(**inputs)
            if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                emb = outputs.pooler_output  # (batch, hidden_dim)
            else:
                # Mask out padding positions before averaging so they do not
                # dilute the embedding.
                mask = inputs["attention_mask"].unsqueeze(-1).type_as(outputs.last_hidden_state)
                emb = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        return emb
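

# Usage sketch: a minimal, illustrative example of calling TextEmbedder.
# The checkpoint name below is an assumption, not something specified above;
# any Hugging Face text encoder aligned with your projector can be substituted.
if __name__ == "__main__":
    embedder = TextEmbedder("sentence-transformers/all-MiniLM-L6-v2")
    embeddings = embedder.text_to_embedding(
        ["a photo of a cat", "two dogs playing in the snow"]
    )
    print(embeddings.shape)  # (2, hidden_dim), e.g. (2, 384) for MiniLM-L6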