text2img/embed/text_embedder.py
import pickle
import logging
from pathlib import Path
from typing import Dict, Union

import torch
import torch.nn as nn
import numpy as np
from sentence_transformers import SentenceTransformer
from cachetools import LRUCache
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use a model that's good for semantic similarity
DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32-multilingual-v1"  # CLIP-based, good for images
OUTPUT_DIM = 4096
ARTICLE_CACHE_SIZE = 1024

# Persistence paths
EMBEDDINGS_DIR = Path(__file__).parent.parent / "embeddings"
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"


class SimpleSentenceEmbedder:
    """
    Simple text embedder using sentence-transformers with a CLIP backbone.
    This should give much better text-image alignment.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        output_dim: int = OUTPUT_DIM,
    ):
        # Load sentence transformer model
        self.model = SentenceTransformer(model_name, device=DEVICE)
        # Get the model's output dimension
        model_dim = self.model.get_sentence_embedding_dimension()
        # Simple projector to match the 4096-D image embeddings
        self.projector = nn.Sequential(
            nn.Linear(model_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ).to(DEVICE)
        # Better initialization
        for layer in self.projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
        self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)
        # Ensure embeddings dir exists
        EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(
            f"Initialized SimpleSentenceEmbedder with {model_name} "
            f"({model_dim}D -> {output_dim}D)"
        )

    def text_to_embedding(self, text: str) -> torch.Tensor:
        """
        Embed text using sentence transformer + simple projector
        """
        if text in self.article_cache:
            return self.article_cache[text]
        # Get sentence embedding
        with torch.no_grad():
            sentence_emb = self.model.encode([text], convert_to_tensor=True, device=DEVICE)
            sentence_emb = nn.functional.normalize(sentence_emb, p=2, dim=1)
            # Project to target dimension
            projected = self.projector(sentence_emb)
            projected = nn.functional.normalize(projected, p=2, dim=1)
        result = projected.squeeze(0).cpu()
        self.article_cache[text] = result
        return result

    def embed_file(
        self,
        text_file: Union[str, Path],
        save_path: Union[str, Path] = TEXT_EMB_PATH,
    ) -> Dict[str, np.ndarray]:
        """
        Read newline-separated articles from text_file, embed any that are
        missing, and persist the merged mapping to save_path.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # Load any previously saved embeddings so existing work is reused
        try:
            with open(save_path, "rb") as fp:
                mapping: Dict[str, np.ndarray] = pickle.load(fp)
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Failed loading {save_path}: {e}")
            mapping = {}
        # Read the articles, skipping blank lines
        try:
            with open(text_file, "r", encoding="utf-8") as fp:
                lines = [ln.strip() for ln in fp if ln.strip()]
        except Exception as e:
            raise RuntimeError(f"Error reading {text_file}: {e}")
        # Embed only articles that are not already in the mapping
        for ln in tqdm(lines, desc="Embedding articles"):
            if ln not in mapping:
                try:
                    vec = self.text_to_embedding(ln)
                    mapping[ln] = vec.cpu().numpy()
                except Exception as e:
                    logger.error(f"Embedding failed for article: {e}")
        # Write the merged mapping atomically so a crash cannot corrupt the file
        with atomic_write(save_path, mode='wb', overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
        return mapping

    def get_memory_usage(self):
        """Get current GPU memory usage if available."""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,  # GB
            }
        return {'allocated': 0, 'reserved': 0}


# Alias for compatibility
TextEmbedder = SimpleSentenceEmbedder
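

# Minimal usage sketch (illustrative only): embeds two example strings and
# compares them. The texts and the "articles.txt" path below are hypothetical
# placeholders, not part of this module's API.
if __name__ == "__main__":
    embedder = TextEmbedder()

    # Single-text embedding: returns a 4096-D, L2-normalized CPU tensor
    vec = embedder.text_to_embedding("A cat sitting on a windowsill")
    logger.info(f"Embedding shape: {tuple(vec.shape)}, norm: {vec.norm():.3f}")

    # Cosine similarity between two projected text embeddings
    # (both are unit vectors, so a dot product suffices)
    other = embedder.text_to_embedding("A dog running on the beach")
    sim = torch.dot(vec, other).item()
    logger.info(f"Cosine similarity: {sim:.3f}")

    # Batch embedding from a newline-separated file (hypothetical path),
    # merged into the default pickle at TEXT_EMB_PATH:
    # mapping = embedder.embed_file("articles.txt")
    # logger.info(f"Total embeddings stored: {len(mapping)}")

    logger.info(f"GPU memory: {embedder.get_memory_usage()}")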