139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
import pickle
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Dict, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
from sentence_transformers import SentenceTransformer
|
|
from cachetools import LRUCache
|
|
from tqdm import tqdm
|
|
from atomicwrites import atomic_write
|
|
|
|
# --- Logging ----------------------------------------------------------------
# NOTE(review): basicConfig at import time configures the process-wide root
# logger (it is a no-op if the root logger already has handlers). Consider
# moving it to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Runtime defaults -------------------------------------------------------
# Use the GPU whenever torch can see one; otherwise fall back to the CPU.
_DEVICE_KIND = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_DEVICE_KIND)

# Default SentenceTransformer checkpoint. Its name indicates a multilingual
# text encoder tied to CLIP ViT-B/32, presumably chosen so that text lands
# near CLIP image vectors (confirm against the image-side encoder).
DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32-multilingual-v1"

OUTPUT_DIM = 4096          # width of the projected text embedding
ARTICLE_CACHE_SIZE = 1024  # capacity of the in-memory text -> embedding LRU cache

# --- Persistence ------------------------------------------------------------
# "embeddings" directory two levels up from this file, plus the default
# pickle holding text embeddings.
EMBEDDINGS_DIR = Path(__file__).parent.parent.joinpath("embeddings")
TEXT_EMB_PATH = EMBEDDINGS_DIR.joinpath("text_embeddings.pkl")
|
|
|
|
|
|
class SimpleSentenceEmbedder:
    """
    Text embedder: a SentenceTransformer encoder followed by an MLP projector.

    Each text is encoded with ``self.model`` and L2-normalised. It is then
    mapped through a two-layer projector (``model_dim -> output_dim ->
    output_dim`` with a ReLU in between) and L2-normalised again. Results are
    memoised in an in-memory LRU cache. :meth:`embed_file` embeds a whole file
    and pickles the results.

    NOTE(review): the projector is only initialised here and never trained in
    this class. Any alignment with separately produced image embeddings must
    come from weights trained or loaded elsewhere; confirm with callers.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        output_dim: int = OUTPUT_DIM,
        seed: Union[int, None] = 0,
    ):
        """
        Load the encoder, build the projector and prepare the caches.

        Args:
            model_name: Name or path passed to ``SentenceTransformer``.
            output_dim: Width of the projected embedding.
            seed: Seed for the projector's initial weights. With a fixed seed,
                every instance gets identical weights, so vectors persisted by
                :meth:`embed_file` in one run stay comparable with vectors
                computed in a later run. ``None`` draws from torch's global RNG
                instead, which is non-reproducible unless the caller seeds it.
        """
        self.model = SentenceTransformer(model_name, device=DEVICE)
        model_dim = self.model.get_sentence_embedding_dimension()

        # Built and initialised on CPU, so the seeded weights are identical
        # whatever DEVICE is, then moved to DEVICE.
        self.projector = self._build_projector(model_dim, output_dim, seed).to(DEVICE)

        self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)

        # Directory that holds embed_file's default save_path.
        EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

        logger.info(f"Initialized SentenceEmbedder with {model_name} ({model_dim}D -> {output_dim}D)")

    @staticmethod
    def _build_projector(in_dim: int, out_dim: int, seed: Union[int, None]) -> nn.Sequential:
        """Create the CPU projector with Xavier-uniform weights and zero biases."""
        projector = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        linears = [layer for layer in projector if isinstance(layer, nn.Linear)]

        if seed is None:
            # Legacy behaviour: weights come from torch's global RNG.
            for layer in linears:
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
            return projector

        # A private generator gives reproducible weights without reseeding or
        # advancing the global RNG that the rest of the process may rely on.
        generator = torch.Generator().manual_seed(seed)
        with torch.no_grad():  # in-place init on leaf parameters needs no_grad
            for layer in linears:
                fan_out, fan_in = layer.weight.shape  # nn.Linear weight is (out, in)
                # Same bound nn.init.xavier_uniform_ uses with gain=1.
                bound = (6.0 / (fan_in + fan_out)) ** 0.5
                layer.weight.uniform_(-bound, bound, generator=generator)
                layer.bias.zero_()
        return projector

    def text_to_embedding(self, text: str) -> torch.Tensor:
        """
        Embed ``text`` with the sentence encoder and the projector.

        Args:
            text: Input string, used verbatim as the cache key.

        Returns:
            A detached 1-D CPU tensor of length ``output_dim`` with unit L2
            norm. Cache hits return the same tensor object, so callers should
            not modify it in place.
        """
        if text in self.article_cache:
            return self.article_cache[text]

        # Both the encoder and the projector run without autograd. The
        # projector's parameters require grad, so a forward pass outside
        # no_grad would return a tensor that .numpy() rejects and would pin its
        # autograd graph in the cache.
        with torch.no_grad():
            sentence_emb = self.model.encode([text], convert_to_tensor=True, device=DEVICE)
            sentence_emb = nn.functional.normalize(sentence_emb, p=2, dim=1)

            # Project to target dimension, then re-normalise.
            projected = self.projector(sentence_emb)
            projected = nn.functional.normalize(projected, p=2, dim=1)

        result = projected.squeeze(0).detach().cpu()
        self.article_cache[text] = result
        return result

    @staticmethod
    def _load_mapping(save_path: Path) -> Dict[str, np.ndarray]:
        """Return the mapping pickled at ``save_path``, or ``{}`` if it is missing or unusable."""
        # SECURITY: pickle.load can execute arbitrary code. Only point
        # save_path at files written by this code or another trusted source.
        try:
            with open(save_path, "rb") as fp:
                loaded = pickle.load(fp)
        except FileNotFoundError:
            return {}
        except Exception as e:
            logger.warning(f"Failed loading {save_path}: {e}")
            return {}

        if not isinstance(loaded, dict):
            logger.warning(f"Ignoring {save_path}: expected a dict, got {type(loaded).__name__}")
            return {}
        return loaded

    def embed_file(
        self,
        text_file: Union[str, Path],
        save_path: Union[str, Path] = TEXT_EMB_PATH,
    ) -> Dict[str, np.ndarray]:
        """
        Embed every non-blank line of ``text_file`` and persist the results.

        The existing mapping in ``save_path`` (a pickled ``{text: ndarray}``
        dict) is loaded first, and lines already present in it are not
        re-embedded. The merged mapping is then written back atomically.

        NOTE(review): reused entries were produced by whichever projector
        weights existed when they were saved. They are only comparable to new
        vectors if both runs used the same ``seed``/weights.

        Args:
            text_file: UTF-8 text file with one article per line. Lines are
                stripped and blank lines are skipped.
            save_path: Pickle file used both as the input cache and the output.

        Returns:
            The merged mapping from article text to embedding vector.

        Raises:
            RuntimeError: If ``text_file`` cannot be read.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        mapping = self._load_mapping(save_path)

        try:
            with open(text_file, "r", encoding="utf-8") as fp:
                lines = [ln.strip() for ln in fp if ln.strip()]
        except Exception as e:
            raise RuntimeError(f"Error reading {text_file}: {e}") from e

        for ln in tqdm(lines, desc="Embedding articles"):
            if ln in mapping:
                continue  # already embedded (earlier run or duplicate line)
            try:
                vec = self.text_to_embedding(ln)
                mapping[ln] = vec.detach().cpu().numpy()
            except Exception as e:
                # Best effort: one failing article must not abort the batch.
                logger.error("Embedding failed for article: %s", e)

        # Pass a plain str: pathlib.Path support in atomicwrites may depend on
        # the installed version.
        with atomic_write(str(save_path), mode='wb', overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)

        return mapping

    def get_memory_usage(self) -> Dict[str, float]:
        """
        Report torch's CUDA memory counters in GiB.

        Returns:
            ``{'allocated': ..., 'reserved': ...}``, taken from
            ``torch.cuda.memory_allocated`` / ``torch.cuda.memory_reserved``
            divided by 1024**3, or zeros when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return {'allocated': 0, 'reserved': 0}
        gib = 1024 ** 3
        return {
            'allocated': torch.cuda.memory_allocated() / gib,
            'reserved': torch.cuda.memory_reserved() / gib,
        }
|
|
|
|
|
|
# Alias for compatibility
|
|
# Older public name kept bound to the current implementation class, so code
# that imports ``TextEmbedder`` keeps working.
TextEmbedder = SimpleSentenceEmbedder
|