# Img2Vec/text_embedder.py
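"""CLIP-based text embedder for Img2Vec.

Encodes text with CLIP's text encoder and projects the features to a
4096-dimensional space to match the image embeddings. Usable as a library
(CLIPTextEmbedder / TextEmbedder) or from the command line (see main()).
"""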

import pickle
import logging
from pathlib import Path
from typing import Dict, List, Union
import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPTokenizer, CLIPTextModel
from cachetools import LRUCache
from tqdm import tqdm
from atomicwrites import atomic_write
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use CLIP which is specifically designed for text-image alignment
DEFAULT_MODEL = "openai/clip-vit-base-patch32"
OUTPUT_DIM = 4096 # Match image embeddings dimension
CHUNK_BATCH_SIZE = 16
ARTICLE_CACHE_SIZE = 1024
# Persistence paths
EMBEDDINGS_DIR = Path(__file__).parent / "embeddings"
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"

class CLIPTextEmbedder:
    """
    Text embedder using CLIP's text encoder, optimized for text-to-image
    retrieval, with 4096-dimensional output to match the image embeddings.
    See the usage sketch after the class definition.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        output_dim: int = OUTPUT_DIM,
        chunk_batch_size: int = CHUNK_BATCH_SIZE,
    ):
        logger.info(f"Initializing CLIPTextEmbedder with {model_name}")
        # Load CLIP's tokenizer and text encoder
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_encoder = CLIPTextModel.from_pretrained(model_name).to(DEVICE).eval()
        # Freeze the text encoder
        for p in self.text_encoder.parameters():
            p.requires_grad = False
        # CLIP's text embedding dimension (usually 512)
        clip_dim = self.text_encoder.config.hidden_size
        # MLP projector mapping CLIP features to the 4096-D image embedding
        # space. Note: the projector is randomly initialized below, so its
        # outputs are not aligned with the image embeddings until it is
        # trained or loaded from a checkpoint. eval() keeps Dropout inactive
        # at inference time.
        self.projector = nn.Sequential(
            nn.Linear(clip_dim, output_dim // 2),
            nn.LayerNorm(output_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim // 2, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ).to(DEVICE).eval()
        # Xavier initialization for the projector's linear layers
        for layer in self.projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
        self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)
        self.chunk_batch_size = chunk_batch_size
        # Ensure the embeddings directory exists
        EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
        logger.info(f"CLIPTextEmbedder ready: {clip_dim}D -> {output_dim}D on {DEVICE}")

    def text_to_embedding(self, text: Union[str, List[str]]) -> torch.Tensor:
        """
        Embed text using CLIP's text encoder followed by the projector.

        Args:
            text: A single string or a list of strings.

        Returns:
            torch.Tensor: L2-normalized embeddings of shape (D,) for a
            single text, or (B, D) for a batch of texts.
        """
        # Normalize the input to a list, remembering if it was a single text
        if isinstance(text, str):
            texts = [text]
            single_text = True
        else:
            texts = text
            single_text = False

        # Serve single texts from the LRU cache when possible
        if single_text and texts[0] in self.article_cache:
            return self.article_cache[texts[0]]

        # Tokenize and encode with CLIP
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=77,  # CLIP's maximum sequence length
        ).to(DEVICE)
        with torch.no_grad():
            # Pooled CLIP text features (EOS-token representation)
            text_outputs = self.text_encoder(**inputs)
            clip_features = text_outputs.pooler_output  # [batch_size, clip_dim]
            # L2-normalize CLIP features before projecting
            clip_features = nn.functional.normalize(clip_features, p=2, dim=1)
            # Project to the target dimension and normalize the result
            projected = self.projector(clip_features)
            projected = nn.functional.normalize(projected, p=2, dim=1)

        # Cache single-text results before returning
        if single_text:
            result = projected.squeeze(0).cpu()
            self.article_cache[texts[0]] = result
            return result
        return projected.cpu()
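
    # Illustrative shapes with the default 4096-D output:
    #   text_to_embedding("a cat")            -> torch.Size([4096])
    #   text_to_embedding(["a cat", "a dog"]) -> torch.Size([2, 4096])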

    def embed_texts_from_file(
        self,
        text_file: Union[str, Path],
        save_path: Union[str, Path] = TEXT_EMB_PATH,
    ) -> Dict[str, np.ndarray]:
        """
        Read newline-separated texts from a file, embed any that are not
        already in the pickle at save_path, and return the full
        text -> embedding mapping.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Load existing embeddings, if any
        try:
            with open(save_path, "rb") as fp:
                mapping: Dict[str, np.ndarray] = pickle.load(fp)
            logger.info(f"Loaded {len(mapping)} existing text embeddings")
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Failed loading {save_path}: {e}")
            mapping = {}

        # Read the text file, skipping blank lines
        try:
            with open(text_file, "r", encoding="utf-8") as fp:
                lines = [ln.strip() for ln in fp if ln.strip()]
        except OSError as e:
            raise RuntimeError(f"Error reading {text_file}: {e}") from e
        logger.info(f"Processing {len(lines)} texts from {text_file}")

        # Skip texts that are already embedded
        new_texts = [ln for ln in lines if ln not in mapping]
        logger.info(f"Embedding {len(new_texts)} new texts")
        if not new_texts:
            logger.info("All texts already embedded!")
            return mapping

        # Process in batches
        for i in tqdm(range(0, len(new_texts), self.chunk_batch_size), desc="Embedding texts"):
            batch_texts = new_texts[i:i + self.chunk_batch_size]
            try:
                batch_embeddings = self.text_to_embedding(batch_texts)
                for text, embedding in zip(batch_texts, batch_embeddings):
                    mapping[text] = embedding.numpy()
            except Exception as e:
                logger.error(f"Embedding failed for batch starting at {i}: {e}")
                # Fall back to embedding texts one at a time
                for text in batch_texts:
                    try:
                        vec = self.text_to_embedding(text)
                        mapping[text] = vec.numpy()
                    except Exception as e2:
                        logger.error(f"Individual embedding failed for text: {e2}")

        # Save the updated mapping atomically
        with atomic_write(save_path, mode="wb", overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Saved {len(mapping)} text embeddings to {save_path}")
        return mapping
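
    # Example input for embed_texts_from_file: a plain-text file with one
    # caption per line (blank lines are skipped), e.g.:
    #   a cat sleeping on a bed
    #   a dog running in the park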

    def embed_single_text(self, text: str) -> np.ndarray:
        """Embed a single text and return it as a numpy array."""
        return self.text_to_embedding(text).numpy()

    def clear_cache(self):
        """Clear the text embedding cache."""
        self.article_cache.clear()
        logger.info("Text cache cleared")

    def get_memory_usage(self):
        """Get current GPU memory usage in GiB, if CUDA is available."""
        if torch.cuda.is_available():
            return {
                "allocated": torch.cuda.memory_allocated() / 1024 ** 3,  # GiB
                "reserved": torch.cuda.memory_reserved() / 1024 ** 3,  # GiB
            }
        return {"allocated": 0, "reserved": 0}

# Alias for backward compatibility
TextEmbedder = CLIPTextEmbedder
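
# Usage sketch for text-to-image retrieval. The image-embedding store and its
# {path: np.ndarray} layout are assumptions about the rest of Img2Vec, not
# something this module defines:
#
#     embedder = TextEmbedder()
#     query = embedder.embed_single_text("a cat sleeping on a bed")
#     # Both sides are L2-normalized, so cosine similarity is a dot product.
#     best = max(image_embeddings.items(), key=lambda kv: float(np.dot(query, kv[1])))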

def main():
    import argparse

    parser = argparse.ArgumentParser(description="CLIP Text Embedder for Image Search")
    parser.add_argument("--text", type=str, help="Single text to embed")
    parser.add_argument("--file", type=str, help="Path to newline-delimited texts")
    parser.add_argument(
        "--out", type=str, default=str(TEXT_EMB_PATH),
        help="Pickle path for saving embeddings",
    )
    parser.add_argument(
        "--batch-size", type=int, default=CHUNK_BATCH_SIZE,
        help="Batch size for processing",
    )
    parser.add_argument(
        "--model", type=str, default=DEFAULT_MODEL,
        help="CLIP model name to use",
    )
    parser.add_argument(
        "--dim", type=int, default=OUTPUT_DIM,
        help="Output embedding dimension",
    )
    args = parser.parse_args()

    embedder = CLIPTextEmbedder(
        model_name=args.model,
        output_dim=args.dim,
        chunk_batch_size=args.batch_size,
    )

    # Report GPU memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GiB, Reserved: {mem['reserved']:.2f}GiB")

    # Single-text mode
    if args.text:
        vec = embedder.embed_single_text(args.text)
        print(f"Text: {args.text}")
        print(f"Shape: {vec.shape}")
        print(f"Norm: {np.linalg.norm(vec):.4f}")
        print(f"Sample values: {vec[:10]}")
        # Merge the new embedding into the pickle at --out
        save_path = Path(args.out)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(save_path, "rb") as fp:
                mapping = pickle.load(fp)
        except (FileNotFoundError, pickle.PickleError):
            mapping = {}
        mapping[args.text] = vec
        with atomic_write(save_path, mode="wb", overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Saved text embedding to {save_path}")

    # Batch-file mode
    if args.file:
        embedder.embed_texts_from_file(args.file, save_path=args.out)
        logger.info(f"Completed batch embedding from {args.file}")

    # Demo mode when no input is given
    if not args.text and not args.file:
        demo_texts = [
            "a cat sleeping on a bed",
            "a dog running in the park",
            "beautiful sunset over mountains",
            "red sports car on highway",
        ]
        print("Demo: Embedding sample texts...")
        for text in demo_texts:
            vec = embedder.embed_single_text(text)
            print(f"'{text}' -> shape {vec.shape}, norm {np.linalg.norm(vec):.4f}")


if __name__ == "__main__":
    main()