import pickle
import logging
from pathlib import Path
from typing import Dict, Union

import torch
import torch.nn as nn
import numpy as np
from sentence_transformers import SentenceTransformer
from cachetools import LRUCache
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use a model that's good for semantic similarity
DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32-multilingual-v1"  # CLIP-based, good for images
OUTPUT_DIM = 4096
ARTICLE_CACHE_SIZE = 1024

# Persistence paths
EMBEDDINGS_DIR = Path(__file__).parent.parent / "embeddings"
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"


class SimpleSentenceEmbedder:
    """
    Simple text embedder using sentence-transformers with a CLIP backbone.
    This should give much better text-image alignment.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        output_dim: int = OUTPUT_DIM,
    ):
        # Load sentence transformer model
        self.model = SentenceTransformer(model_name, device=DEVICE)

        # Get the model's output dimension
        model_dim = self.model.get_sentence_embedding_dimension()

        # Simple projector to match your 4096-D image embeddings
        self.projector = nn.Sequential(
            nn.Linear(model_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ).to(DEVICE)

        # Better initialization
        for layer in self.projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

        self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)

        # Ensure embeddings dir exists
        EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

        logger.info(
            f"Initialized SentenceEmbedder with {model_name} ({model_dim}D -> {output_dim}D)"
        )

    def text_to_embedding(self, text: str) -> torch.Tensor:
        """Embed a single text using the sentence transformer + simple projector."""
        if text in self.article_cache:
            return self.article_cache[text]

        with torch.no_grad():
            # Get sentence embedding
            sentence_emb = self.model.encode([text], convert_to_tensor=True, device=DEVICE)
            sentence_emb = nn.functional.normalize(sentence_emb, p=2, dim=1)

            # Project to target dimension
            projected = self.projector(sentence_emb)
            projected = nn.functional.normalize(projected, p=2, dim=1)

        result = projected.squeeze(0).cpu()
        self.article_cache[text] = result
        return result

    def embed_file(
        self,
        text_file: Union[str, Path],
        save_path: Union[str, Path] = TEXT_EMB_PATH,
    ) -> Dict[str, np.ndarray]:
        """Read newline-separated articles from text_file and embed any that are
        not already present in the pickle at save_path."""
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Load any previously saved embeddings so work is not repeated
        try:
            with open(save_path, "rb") as fp:
                mapping: Dict[str, np.ndarray] = pickle.load(fp)
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Failed loading {save_path}: {e}")
            mapping = {}

        try:
            with open(text_file, "r", encoding="utf-8") as fp:
                lines = [ln.strip() for ln in fp if ln.strip()]
        except Exception as e:
            raise RuntimeError(f"Error reading {text_file}: {e}")

        for ln in tqdm(lines, desc="Embedding articles"):
            if ln not in mapping:
                try:
                    vec = self.text_to_embedding(ln)
                    mapping[ln] = vec.cpu().numpy()
                except Exception as e:
                    logger.error(f"Embedding failed for article: {e}")

        # Write atomically so a crash mid-dump cannot corrupt the pickle
        with atomic_write(save_path, mode="wb", overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)

        return mapping

    def get_memory_usage(self):
        """Get current GPU memory usage (in GB) if available."""
        if torch.cuda.is_available():
            return {
                "allocated": torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                "reserved": torch.cuda.memory_reserved() / 1024 ** 3,  # GB
            }
        return {"allocated": 0, "reserved": 0}


# Alias for compatibility
TextEmbedder = SimpleSentenceEmbedder
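

# Minimal usage sketch (an illustration, not part of the original module):
# "articles.txt" below is a hypothetical newline-separated file of article texts,
# and the printed shape assumes the default OUTPUT_DIM of 4096.
if __name__ == "__main__":
    embedder = TextEmbedder()

    # Single text: returns an L2-normalized tensor of shape (4096,) on CPU.
    vec = embedder.text_to_embedding("A cat sitting on a windowsill.")
    print(vec.shape)

    # Whole file: embeds each non-empty line and persists the mapping atomically.
    mapping = embedder.embed_file("articles.txt")
    print(f"Embedded {len(mapping)} articles; GPU memory: {embedder.get_memory_usage()}")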