import pickle
import logging
from pathlib import Path
from typing import Dict, List, Union

import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPTokenizer, CLIPTextModel
from cachetools import LRUCache
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use CLIP, which is specifically designed for text-image alignment
DEFAULT_MODEL = "openai/clip-vit-base-patch32"
OUTPUT_DIM = 4096  # Match image embeddings dimension
CHUNK_BATCH_SIZE = 16
ARTICLE_CACHE_SIZE = 1024

# Persistence paths
EMBEDDINGS_DIR = Path(__file__).parent / "embeddings"
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"


class CLIPTextEmbedder:
    """
    Text embedder using CLIP's text encoder, optimized for text-to-image
    retrieval, with 4096-dimensional output to match the image embeddings.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        output_dim: int = OUTPUT_DIM,
        chunk_batch_size: int = CHUNK_BATCH_SIZE,
    ):
        logger.info(f"Initializing CLIPTextEmbedder with {model_name}")

        # Load CLIP text encoder
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_encoder = CLIPTextModel.from_pretrained(model_name).to(DEVICE).eval()

        # Freeze the text encoder
        for p in self.text_encoder.parameters():
            p.requires_grad = False

        # Get CLIP's text embedding dimension (usually 512)
        clip_dim = self.text_encoder.config.hidden_size

        # MLP projector mapping CLIP features up to the 4096-D image-embedding dimension
        self.projector = nn.Sequential(
            nn.Linear(clip_dim, output_dim // 2),
            nn.LayerNorm(output_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim // 2, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ).to(DEVICE)
        # Inference-only module: eval() disables the dropout layer so
        # embeddings are deterministic
        self.projector.eval()

        # Initialize projector with Xavier initialization
        for layer in self.projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

        self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)
        self.chunk_batch_size = chunk_batch_size

        # Ensure embeddings dir exists
        EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"CLIPTextEmbedder ready: {clip_dim}D -> {output_dim}D on {DEVICE}")

    def text_to_embedding(self, text: Union[str, List[str]]) -> torch.Tensor:
        """
        Embed text using CLIP's text encoder + projector.

        Args:
            text: Single string or list of strings

        Returns:
            torch.Tensor: Normalized embeddings of shape (D,) for a single
            text or (B, D) for a batch of texts
        """
        # Handle single text vs batch
        if isinstance(text, str):
            texts = [text]
            single_text = True
        else:
            texts = text
            single_text = False

        # Check cache for single text
        if single_text and texts[0] in self.article_cache:
            return self.article_cache[texts[0]]

        # Tokenize and encode with CLIP
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=77,  # CLIP's max sequence length
        ).to(DEVICE)

        with torch.no_grad():
            # Get CLIP text features
            text_outputs = self.text_encoder(**inputs)
            # Use the pooled output (final hidden state of the EOS token)
            clip_features = text_outputs.pooler_output  # [batch_size, clip_dim]

            # Normalize CLIP features
            clip_features = nn.functional.normalize(clip_features, p=2, dim=1)

            # Project to target dimension
            projected = self.projector(clip_features)

            # Normalize final output
            projected = nn.functional.normalize(projected, p=2, dim=1)

        # Cache single text result
        if single_text:
            result = projected.squeeze(0).cpu()
            self.article_cache[texts[0]] = result
            return result
        else:
            return projected.cpu()  # Return batch

    def embed_texts_from_file(
        self,
        text_file: Union[str, Path],
        save_path: Union[str, Path] = TEXT_EMB_PATH,
    ) -> Dict[str, np.ndarray]:
        """
        Read newline-separated texts from a file and embed them.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Load existing embeddings
        try:
            with open(save_path, "rb") as fp:
                mapping: Dict[str, np.ndarray] = pickle.load(fp)
            logger.info(f"Loaded {len(mapping)} existing text embeddings")
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Failed loading {save_path}: {e}")
            mapping = {}

        # Read text file
        try:
            with open(text_file, "r", encoding="utf-8") as fp:
                lines = [ln.strip() for ln in fp if ln.strip()]
        except Exception as e:
            raise RuntimeError(f"Error reading {text_file}: {e}")

        logger.info(f"Processing {len(lines)} texts from {text_file}")

        # Identify texts that still need embedding
        new_texts = [ln for ln in lines if ln not in mapping]
        logger.info(f"Embedding {len(new_texts)} new texts")

        if not new_texts:
            logger.info("All texts already embedded!")
            return mapping

        # Process in batches
        for i in tqdm(range(0, len(new_texts), self.chunk_batch_size), desc="Embedding texts"):
            batch_texts = new_texts[i:i + self.chunk_batch_size]
            try:
                # Get batch embeddings
                batch_embeddings = self.text_to_embedding(batch_texts)
                # Store in mapping
                for text, embedding in zip(batch_texts, batch_embeddings):
                    mapping[text] = embedding.numpy()
            except Exception as e:
                logger.error(f"Embedding failed for batch starting at {i}: {e}")
                # Fall back to embedding texts individually
                for text in batch_texts:
                    try:
                        vec = self.text_to_embedding(text)
                        mapping[text] = vec.numpy()
                    except Exception as e2:
                        logger.error(f"Individual embedding failed for text: {e2}")

        # Save updated mappings atomically
        with atomic_write(save_path, mode="wb", overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Saved {len(mapping)} text embeddings to {save_path}")
        return mapping

    def embed_single_text(self, text: str) -> np.ndarray:
        """Embed a single text and return it as a numpy array."""
        embedding = self.text_to_embedding(text)
        return embedding.numpy()

    def clear_cache(self):
        """Clear the text cache."""
        self.article_cache.clear()
        logger.info("Text cache cleared")

    def get_memory_usage(self):
        """Get current GPU memory usage if available."""
        if torch.cuda.is_available():
            return {
                "allocated": torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                "reserved": torch.cuda.memory_reserved() / 1024 ** 3,  # GB
            }
        return {"allocated": 0, "reserved": 0}


# Alias for compatibility
TextEmbedder = CLIPTextEmbedder
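
# ---------------------------------------------------------------------------
# Retrieval sketch (illustrative only). The class docstring targets
# text-to-image retrieval against 4096-D image embeddings; the helper below
# is a minimal sketch of that lookup. It assumes the image embeddings are
# stored as a pickle mapping image identifiers to L2-normalized 4096-D numpy
# vectors -- the "image_embeddings.pkl" filename and that layout are
# assumptions of this example, not something this module defines.
# ---------------------------------------------------------------------------
def retrieve_images_sketch(
    embedder: CLIPTextEmbedder,
    query: str,
    image_emb_path: Path = EMBEDDINGS_DIR / "image_embeddings.pkl",  # hypothetical path
    top_k: int = 5,
) -> List[tuple]:
    """Rank images by cosine similarity to a text query (illustrative sketch)."""
    with open(image_emb_path, "rb") as fp:
        image_map: Dict[str, np.ndarray] = pickle.load(fp)  # assumed: id -> 4096-D vector
    query_vec = embedder.embed_single_text(query)
    names = list(image_map.keys())
    matrix = np.stack([image_map[n] for n in names])  # [num_images, 4096]
    # Both sides are L2-normalized, so the dot product equals cosine similarity
    scores = matrix @ query_vec
    top = np.argsort(scores)[::-1][:top_k]
    return [(names[i], float(scores[i])) for i in top]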
def main():
    import argparse

    parser = argparse.ArgumentParser("CLIP Text Embedder for Image Search")
    parser.add_argument("--text", type=str, help="Single text to embed")
    parser.add_argument("--file", type=str, help="Path to newline-delimited texts")
    parser.add_argument(
        "--out", type=str, default=str(TEXT_EMB_PATH),
        help="Pickle path for saving embeddings",
    )
    parser.add_argument(
        "--batch-size", type=int, default=CHUNK_BATCH_SIZE,
        help="Batch size for processing",
    )
    parser.add_argument(
        "--model", type=str, default=DEFAULT_MODEL,
        help="CLIP model name to use",
    )
    parser.add_argument(
        "--dim", type=int, default=OUTPUT_DIM,
        help="Output embedding dimension",
    )
    args = parser.parse_args()

    embedder = CLIPTextEmbedder(
        model_name=args.model,
        output_dim=args.dim,
        chunk_batch_size=args.batch_size,
    )

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(
        f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, "
        f"Reserved: {mem['reserved']:.2f}GB"
    )

    # Single-text mode
    if args.text:
        vec = embedder.embed_single_text(args.text)
        print(f"Text: {args.text}")
        print(f"Shape: {vec.shape}")
        print(f"Norm: {np.linalg.norm(vec):.4f}")
        print(f"Sample values: {vec[:10]}")

        # Save to file
        save_path = Path(args.out)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(save_path, "rb") as fp:
                mapping = pickle.load(fp)
        except (FileNotFoundError, pickle.PickleError):
            mapping = {}
        mapping[args.text] = vec
        with atomic_write(save_path, mode="wb", overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Saved text embedding to {save_path}")

    # Batch-file mode
    if args.file:
        embedder.embed_texts_from_file(args.file, save_path=args.out)
        logger.info(f"Completed batch embedding from {args.file}")

    if not args.text and not args.file:
        # Demo mode
        demo_texts = [
            "a cat sleeping on a bed",
            "a dog running in the park",
            "beautiful sunset over mountains",
            "red sports car on highway",
        ]
        print("Demo: Embedding sample texts...")
        for text in demo_texts:
            vec = embedder.embed_single_text(text)
            print(f"'{text}' -> shape {vec.shape}, norm {np.linalg.norm(vec):.4f}")


if __name__ == "__main__":
    main()
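
# Example invocations (the module filename "text_embedder.py" is a placeholder;
# all flags are defined in main() above):
#
#   python text_embedder.py --text "a cat sleeping on a bed"
#   python text_embedder.py --file captions.txt --out embeddings/text_embeddings.pkl
#   python text_embedder.py --file captions.txt --batch-size 32 --dim 4096
#   python text_embedder.py  # no flags: runs the built-in demo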