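"""Text-to-image retrieval sketch.

Embeds a text query with a multilingual CLIP sentence transformer, projects it
into the 4096-D image-embedding space, and ranks precomputed image embeddings
by cosine similarity.
"""
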
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

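# Adjust IMG_EMB_PKL to point at your precomputed image-embedding pickle.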
IMG_EMB_PKL = Path("E:/GiteaRepo/text2img/embeddings/image_embeddings.pkl")
TOP_K = 5
query_text = """orange cat"""


def get_text_embedding(text, model, projector, device):
    """Get text embedding using sentence transformer + projector."""
    with torch.no_grad():
        # Get sentence embedding
        sentence_emb = model.encode([text], convert_to_tensor=True, device=device)
        sentence_emb = nn.functional.normalize(sentence_emb, p=2, dim=1)

        # Project to 4096D
        projected = projector(sentence_emb)
        projected = nn.functional.normalize(projected, p=2, dim=1)

        return projected.squeeze(0).cpu().numpy()


def main():
    print(f"Loading image embeddings from {IMG_EMB_PKL}")

    # 1) Load image embeddings
    try:
        with open(IMG_EMB_PKL, "rb") as f:
            data = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: Could not find {IMG_EMB_PKL}")
        return

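    # The pickle is expected to map image path -> 4096-D embedding vector.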
    paths = np.array(list(data.keys()))
    vecs = np.stack(list(data.values()))  # shape (N, 4096)
    print(f"Loaded {len(paths)} image embeddings with shape {vecs.shape}")

    # 2) Initialize text embedding model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Use CLIP-based sentence transformer for better text-image alignment
    model_name = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
    print(f"Loading text model: {model_name}")

    try:
        model = SentenceTransformer(model_name, device=device)
        model_dim = model.get_sentence_embedding_dimension()
        print(f"Text model dimension: {model_dim}")

        # Simple projector to match 4096D image embeddings
        projector = nn.Sequential(
            nn.Linear(model_dim, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
        ).to(device)

        # Better initialization
        for layer in projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    except Exception as e:
        print(f"Error loading text model: {e}")
        return

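    # NOTE: the projector above is randomly initialized (Xavier) and never
    # trained, so projected text vectors are not yet semantically aligned with
    # the image space; load trained projector weights here for meaningful
    # retrieval results.
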
    # 3) Embed the query
    print(f"Embedding query: '{query_text}'")
    try:
        q = get_text_embedding(query_text, model, projector, device)
        print(f"Query embedding shape: {q.shape}")
        print(f"Query embedding norm: {np.linalg.norm(q):.4f}")
    except Exception as e:
        print(f"Error embedding query: {e}")
        return

    # 4) Cosine similarity
    print("Computing similarities...")
    # Dot product equals cosine similarity here: the query is L2-normalized in
    # get_text_embedding, and the stored image embeddings are assumed to be
    # L2-normalized as well.
    scores = vecs @ q
    top_k_idx = np.argsort(-scores)[:TOP_K]

    # 5) Display results
    print(f"\nTop-{TOP_K} images for: {query_text!r}\n")
    print("-" * 80)
    for rank, idx in enumerate(top_k_idx, 1):
        print(f"{rank:>2}. {paths[idx]:<60} score={scores[idx]:>7.4f}")
    print("-" * 80)

    # Show stats
    print("\nStatistics:")
    print(f"Mean similarity: {scores.mean():.4f}")
    print(f"Max similarity: {scores.max():.4f}")
    print(f"Min similarity: {scores.min():.4f}")
    print(f"Std similarity: {scores.std():.4f}")


if __name__ == "__main__":
    main()