# text2img/examples/search.py
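"""Text-to-image retrieval example.

Loads precomputed image embeddings from a pickle file, embeds a text query
with a SentenceTransformer plus a projection head, and prints the TOP_K
image paths with the highest cosine similarity to the query.
"""
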
import numpy as np
import pickle
from pathlib import Path
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
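
# IMG_EMB_PKL is expected to hold a dict of {image path: 4096-D embedding vector},
# as unpacked in main() below.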
IMG_EMB_PKL = Path("E:/GiteaRepo/text2img/embeddings/image_embeddings.pkl")
TOP_K = 5
query_text = """orange cat"""


def get_text_embedding(text, model, projector, device):
    """Get text embedding using sentence transformer + projector."""
    with torch.no_grad():
        # Get sentence embedding
        sentence_emb = model.encode([text], convert_to_tensor=True, device=device)
        sentence_emb = nn.functional.normalize(sentence_emb, p=2, dim=1)
        # Project to 4096D
        projected = projector(sentence_emb)
        projected = nn.functional.normalize(projected, p=2, dim=1)
        return projected.squeeze(0).cpu().numpy()
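

# Optional helper (not called by default): the dot-product ranking in main()
# assumes the stored image embeddings are already unit-normalized; if they are
# not, re-normalize them first, e.g. `vecs = normalize_rows(vecs)` after loading.
def normalize_rows(mat):
    """L2-normalize each row of a 2-D array."""
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    return mat / np.clip(norms, 1e-12, None)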


def main():
    print(f"Loading image embeddings from {IMG_EMB_PKL}")

    # 1) Load image embeddings
    try:
        with open(IMG_EMB_PKL, "rb") as f:
            data = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: Could not find {IMG_EMB_PKL}")
        return

    paths = np.array(list(data.keys()))
    vecs = np.stack(list(data.values()))  # shape (N, 4096)
    print(f"Loaded {len(paths)} image embeddings with shape {vecs.shape}")

    # 2) Initialize text embedding model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Use a CLIP-based sentence transformer for better text-image alignment
    model_name = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
    print(f"Loading text model: {model_name}")
    try:
        model = SentenceTransformer(model_name, device=device)
        model_dim = model.get_sentence_embedding_dimension()
        print(f"Text model dimension: {model_dim}")

        # Simple projector to match the 4096-D image embeddings
        projector = nn.Sequential(
            nn.Linear(model_dim, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
        ).to(device)

        # Xavier initialization so the projector starts from sane weights
        for layer in projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)
    except Exception as e:
        print(f"Error loading text model: {e}")
        return
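
    # Note: the projector above is freshly initialized, not trained, so the
    # similarity scores below are only meaningful once trained weights exist.
    # A minimal sketch, assuming a saved checkpoint (PROJECTOR_CKPT is a
    # hypothetical path, not defined in this file):
    #     projector.load_state_dict(torch.load(PROJECTOR_CKPT, map_location=device))
    #     projector.eval()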

    # 3) Embed the query
    print(f"Embedding query: '{query_text}'")
    try:
        q = get_text_embedding(query_text, model, projector, device)
        print(f"Query embedding shape: {q.shape}")
        print(f"Query embedding norm: {np.linalg.norm(q):.4f}")
    except Exception as e:
        print(f"Error embedding query: {e}")
        return

    # 4) Cosine similarity
    print("Computing similarities...")
    scores = vecs @ q  # Dot product since both are normalized
    top_k_idx = np.argsort(-scores)[:TOP_K]

    # 5) Display results
    print(f"\nTop-{TOP_K} images for: {query_text!r}\n")
    print("-" * 80)
    for rank, idx in enumerate(top_k_idx, 1):
        print(f"{rank:>2}. {paths[idx]:<60} score={scores[idx]:>7.4f}")
    print("-" * 80)

    # Show stats
    print("\nStatistics:")
    print(f"Mean similarity: {scores.mean():.4f}")
    print(f"Max similarity: {scores.max():.4f}")
    print(f"Min similarity: {scores.min():.4f}")
    print(f"Std similarity: {scores.std():.4f}")


if __name__ == "__main__":
    main()
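
# Usage (assumed invocation, based on the repository path above):
#     python examples/search.py
# after generating image_embeddings.pkl; edit IMG_EMB_PKL, TOP_K, and
# query_text at the top of this file to change the search.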