import pickle
import hashlib
import logging
from pathlib import Path
from typing import Dict, List, Union

import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPTokenizer, CLIPTextModel
from cachetools import LRUCache
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use CLIP, which is trained specifically for text-image alignment
DEFAULT_MODEL = "openai/clip-vit-base-patch32"
OUTPUT_DIM = 4096  # Match image embeddings dimension
CHUNK_BATCH_SIZE = 16
ARTICLE_CACHE_SIZE = 1024

# Persistence paths
EMBEDDINGS_DIR = Path(__file__).parent / "embeddings"
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"


class CLIPTextEmbedder:
    """
    Text embedder using CLIP's text encoder, optimized for text-to-image
    retrieval with 4096-dimensional output to match image embeddings.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        output_dim: int = OUTPUT_DIM,
        chunk_batch_size: int = CHUNK_BATCH_SIZE,
    ):
        logger.info(f"Initializing CLIPTextEmbedder with {model_name}")

        # Load CLIP text encoder
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_encoder = CLIPTextModel.from_pretrained(model_name).to(DEVICE).eval()

        # Freeze the text encoder
        for p in self.text_encoder.parameters():
            p.requires_grad = False

        # Get CLIP's text embedding dimension (usually 512)
        clip_dim = self.text_encoder.config.hidden_size

        # Enhanced projector to match 4096-D image embeddings
        self.projector = nn.Sequential(
            nn.Linear(clip_dim, output_dim // 2),
            nn.LayerNorm(output_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim // 2, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ).to(DEVICE)

        # Initialize projector with Xavier initialization
        for layer in self.projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

        self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)
        self.chunk_batch_size = chunk_batch_size

        # Ensure the embeddings directory exists
        EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

        logger.info(f"CLIPTextEmbedder ready: {clip_dim}D -> {output_dim}D on {DEVICE}")

    def text_to_embedding(self, text: Union[str, List[str]]) -> torch.Tensor:
        """
        Embed text using CLIP's text encoder followed by the projection head.

        Args:
            text: Single string or list of strings.

        Returns:
            torch.Tensor: L2-normalized CPU embeddings of shape (D,) for a
            single text or (B, D) for a batch of texts.
        """
        # Handle single text vs batch
        if isinstance(text, str):
            texts = [text]
            single_text = True
        else:
            texts = text
            single_text = False

        # Check cache for single text
        if single_text and texts[0] in self.article_cache:
            return self.article_cache[texts[0]]

        # Tokenize and encode with CLIP
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=77,  # CLIP's maximum sequence length
        ).to(DEVICE)

        with torch.no_grad():
            # Get CLIP text features
            text_outputs = self.text_encoder(**inputs)
            # Use the pooled output (hidden state at the end-of-text token)
            clip_features = text_outputs.pooler_output  # [batch_size, clip_dim]

            # Normalize CLIP features
            clip_features = nn.functional.normalize(clip_features, p=2, dim=1)

            # Project to the target dimension
            projected = self.projector(clip_features)
            # Normalize the final output
            projected = nn.functional.normalize(projected, p=2, dim=1)

        # Cache the result for a single text
        if single_text:
            result = projected.squeeze(0).cpu()
            self.article_cache[texts[0]] = result
            return result
        return projected.cpu()  # Return the batch
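
    # Illustrative call shapes (a sketch; the query strings are made up and the
    # shapes assume the default OUTPUT_DIM of 4096):
    #
    #   embedder = CLIPTextEmbedder()
    #   single = embedder.text_to_embedding("a red bicycle")               # shape (4096,)
    #   batch = embedder.text_to_embedding(["a red bicycle", "a canal"])   # shape (2, 4096)
    #
    # Both results are L2-normalized CPU tensors, ready for cosine scoring.
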
    def embed_texts_from_file(
        self,
        text_file: Union[str, Path],
        save_path: Union[str, Path] = TEXT_EMB_PATH,
    ) -> Dict[str, np.ndarray]:
        """
        Read newline-separated texts, embed any that are not yet stored, and
        persist the updated text -> embedding mapping.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Load existing embeddings
        try:
            with open(save_path, "rb") as fp:
                mapping: Dict[str, np.ndarray] = pickle.load(fp)
            logger.info(f"Loaded {len(mapping)} existing text embeddings")
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Failed loading {save_path}: {e}")
            mapping = {}

        # Read the text file
        try:
            with open(text_file, "r", encoding="utf-8") as fp:
                lines = [ln.strip() for ln in fp if ln.strip()]
        except Exception as e:
            raise RuntimeError(f"Error reading {text_file}: {e}") from e

        logger.info(f"Processing {len(lines)} texts from {text_file}")

        # Only embed texts that are not already in the mapping
        new_texts = [ln for ln in lines if ln not in mapping]
        logger.info(f"Embedding {len(new_texts)} new texts")

        if not new_texts:
            logger.info("All texts already embedded!")
            return mapping

        # Process in batches
        for i in tqdm(range(0, len(new_texts), self.chunk_batch_size), desc="Embedding texts"):
            batch_texts = new_texts[i:i + self.chunk_batch_size]

            try:
                # Get batch embeddings
                batch_embeddings = self.text_to_embedding(batch_texts)

                # Store in mapping
                for text, embedding in zip(batch_texts, batch_embeddings):
                    mapping[text] = embedding.numpy()

            except Exception as e:
                logger.error(f"Embedding failed for batch starting at {i}: {e}")
                # Fall back to embedding the batch one text at a time
                for text in batch_texts:
                    try:
                        vec = self.text_to_embedding(text)
                        mapping[text] = vec.numpy()
                    except Exception as e2:
                        logger.error(f"Individual embedding failed for text: {e2}")

        # Save the updated mapping atomically
        with atomic_write(save_path, mode='wb', overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.info(f"Saved {len(mapping)} text embeddings to {save_path}")
        return mapping
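
    # Usage sketch (assumes a hypothetical captions.txt with one caption per
    # line, e.g. "a cat sleeping on a bed"):
    #
    #   mapping = embedder.embed_texts_from_file("captions.txt")
    #
    # Re-running the call only embeds lines missing from the pickle, so an
    # interrupted run can simply be restarted.
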
    def embed_single_text(self, text: str) -> np.ndarray:
        """Embed a single text and return it as a numpy array."""
        embedding = self.text_to_embedding(text)
        return embedding.numpy()

    def clear_cache(self):
        """Clear the in-memory text cache."""
        self.article_cache.clear()
        logger.info("Text cache cleared")

    def get_memory_usage(self):
        """Get current GPU memory usage in GiB, if CUDA is available."""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GiB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,  # GiB
            }
        return {'allocated': 0, 'reserved': 0}


# Alias for compatibility
TextEmbedder = CLIPTextEmbedder
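
# Retrieval sketch (illustrative only; "image_embeddings.pkl" and its layout are
# hypothetical stand-ins for whatever the image pipeline produces). Assuming the
# image vectors are L2-normalized like the text vectors, cosine similarity
# reduces to a dot product:
#
#   embedder = CLIPTextEmbedder()
#   query_vec = embedder.embed_single_text("a cat sleeping on a bed")   # (4096,)
#   with open("image_embeddings.pkl", "rb") as fp:
#       image_map = pickle.load(fp)                 # assumed {path: (4096,) ndarray}
#   paths = list(image_map)
#   scores = np.stack([image_map[p] for p in paths]) @ query_vec
#   top5 = [paths[i] for i in np.argsort(-scores)[:5]]
#
# Command-line usage (the module filename here is hypothetical):
#   python text_embedder.py --text "a cat sleeping on a bed"
#   python text_embedder.py --file captions.txt --out embeddings/text_embeddings.pkl

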
def main():
    import argparse

    parser = argparse.ArgumentParser(description="CLIP Text Embedder for Image Search")
    parser.add_argument("--text", type=str, help="Single text to embed")
    parser.add_argument("--file", type=str, help="Path to newline-delimited texts")
    parser.add_argument(
        "--out", type=str, default=str(TEXT_EMB_PATH),
        help="Pickle path for saving embeddings"
    )
    parser.add_argument(
        "--batch-size", type=int, default=CHUNK_BATCH_SIZE,
        help="Batch size for processing"
    )
    parser.add_argument(
        "--model", type=str, default=DEFAULT_MODEL,
        help="CLIP model name to use"
    )
    parser.add_argument(
        "--dim", type=int, default=OUTPUT_DIM,
        help="Output embedding dimension"
    )
    args = parser.parse_args()

    embedder = CLIPTextEmbedder(
        model_name=args.model,
        output_dim=args.dim,
        chunk_batch_size=args.batch_size,
    )

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f} GiB, Reserved: {mem['reserved']:.2f} GiB")

    # Single-text mode
    if args.text:
        vec = embedder.embed_single_text(args.text)
        print(f"Text: {args.text}")
        print(f"Shape: {vec.shape}")
        print(f"Norm: {np.linalg.norm(vec):.4f}")
        print(f"Sample values: {vec[:10]}")

        # Save to file
        save_path = Path(args.out)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(save_path, "rb") as fp:
                mapping = pickle.load(fp)
        except (FileNotFoundError, pickle.PickleError):
            mapping = {}

        mapping[args.text] = vec
        with atomic_write(save_path, mode='wb', overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Saved text embedding to {save_path}")

    # Batch-file mode
    if args.file:
        embedder.embed_texts_from_file(args.file, save_path=args.out)
        logger.info(f"Completed batch embedding from {args.file}")

    if not args.text and not args.file:
        # Demo mode
        demo_texts = [
            "a cat sleeping on a bed",
            "a dog running in the park",
            "beautiful sunset over mountains",
            "red sports car on highway",
        ]

        print("Demo: Embedding sample texts...")
        for text in demo_texts:
            vec = embedder.embed_single_text(text)
            print(f"'{text}' -> shape {vec.shape}, norm {np.linalg.norm(vec):.4f}")


if __name__ == "__main__":
    main()