37 lines
1.4 KiB
Python
37 lines
1.4 KiB
Python
import numpy as np
|
|
import pickle
|
|
from PIL import Image
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
from embedder import ImageEmbedder
|
|
|
|
# ——— Load stored embeddings & mapping ———
vecs = np.load("processed_images/image_vectors.npy")  # (N, D) image embedding matrix

# NOTE(review): pickle.load is unsafe on untrusted files — acceptable here only
# because this index file is produced by our own pipeline; never point it at
# externally supplied data.
with open("processed_images/index_to_file.pkl", "rb") as f:
    idx2file = pickle.load(f)  # dict: idx → filepath

# ——— Specify query image ———
query_path = "datasets/coco/val2017/1140002154.jpg"

# ——— Embed the query image ———
embedder = ImageEmbedder(
    vision_model_name="openai/clip-vit-large-patch14-336",
    proj_out_dim=5120,
    llava_ckpt_path="datasets/pytorch_model-00003-of-00003.bin"
)
img = Image.open(query_path).convert("RGB")
q_vec = embedder.image_to_embedding(img).cpu().numpy()  # (D,) query vector

# ——— Compute similarities & retrieve top-k ———
# cosine_similarity already L2-normalizes both inputs internally, so the
# explicit normalization below is redundant for cosine scoring; it is only
# needed if you switch to a raw inner-product (dot) index.
# vecs_norm = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
# q_vec_norm = q_vec / np.linalg.norm(q_vec)
# sims = cosine_similarity(q_vec_norm.reshape(1, -1), vecs_norm).flatten()
sims = cosine_similarity(q_vec.reshape(1, -1), vecs).flatten()  # (N,)
top5 = sims.argsort()[-5:][::-1]  # indices of the 5 highest scores, best first

# ——— Print out the results ———
print(f"Query image: {query_path}\n")
for rank, idx in enumerate(top5, 1):
    print(f"{rank:>2}. {idx2file[idx]} (score: {sims[idx]:.4f})")
|