import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

IMG_EMB_PKL = Path("E:/GiteaRepo/text2img/embeddings/image_embeddings.pkl")
TOP_K = 5
query_text = "orange cat"


def get_text_embedding(text, model, projector, device):
    """Get a text embedding using the sentence transformer + projector."""
    with torch.no_grad():
        # Encode the sentence and L2-normalize it
        sentence_emb = model.encode([text], convert_to_tensor=True, device=device)
        sentence_emb = nn.functional.normalize(sentence_emb, p=2, dim=1)

        # Project to 4096D and normalize again
        projected = projector(sentence_emb)
        projected = nn.functional.normalize(projected, p=2, dim=1)

        return projected.squeeze(0).cpu().numpy()


def main():
    print(f"Loading image embeddings from {IMG_EMB_PKL}")

    # 1) Load image embeddings
    try:
        with open(IMG_EMB_PKL, "rb") as f:
            data = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: Could not find {IMG_EMB_PKL}")
        return

    paths = np.array(list(data.keys()))
    vecs = np.stack(list(data.values()))  # shape (N, 4096)
    # Ensure the image embeddings are unit-norm so the dot product below is a true cosine similarity
    vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
    print(f"Loaded {len(paths)} image embeddings with shape {vecs.shape}")

    # 2) Initialize the text embedding model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # CLIP-based sentence transformer for better text-image alignment
    model_name = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
    print(f"Loading text model: {model_name}")
    try:
        model = SentenceTransformer(model_name, device=device)
        model_dim = model.get_sentence_embedding_dimension()
        print(f"Text model dimension: {model_dim}")

        # Simple projector mapping text embeddings to the 4096D image embedding space.
        # NOTE: these weights are randomly initialized; for meaningful retrieval the
        # projector should be trained (or loaded from a checkpoint) to align with the
        # image embedding space.
        projector = nn.Sequential(
            nn.Linear(model_dim, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
        ).to(device)

        # Xavier initialization for the linear layers
        for layer in projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
    except Exception as e:
        print(f"Error loading text model: {e}")
        return

    # 3) Embed the query
    print(f"Embedding query: '{query_text}'")
    try:
        q = get_text_embedding(query_text, model, projector, device)
        print(f"Query embedding shape: {q.shape}")
        print(f"Query embedding norm: {np.linalg.norm(q):.4f}")
    except Exception as e:
        print(f"Error embedding query: {e}")
        return

    # 4) Cosine similarity (plain dot product, since both sides are normalized)
    print("Computing similarities...")
    scores = vecs @ q
    top_k_idx = np.argsort(-scores)[:TOP_K]

    # 5) Display results
    print(f"\nTop-{TOP_K} images for: {query_text!r}\n")
    print("-" * 80)
    for rank, idx in enumerate(top_k_idx, 1):
        print(f"{rank:>2}. {paths[idx]:<60} score={scores[idx]:>7.4f}")
    print("-" * 80)

    # Score statistics across the whole collection
    print("\nStatistics:")
    print(f"Mean similarity: {scores.mean():.4f}")
    print(f"Max similarity:  {scores.max():.4f}")
    print(f"Min similarity:  {scores.min():.4f}")
    print(f"Std similarity:  {scores.std():.4f}")


if __name__ == "__main__":
    main()