"""COCO image-text retrieval benchmark.

Embeds COCO val2017 images and captions, then reports Recall@K and median
rank for both image-to-text and text-to-image retrieval.
"""

import argparse
import os

import numpy as np
import torch
from PIL import Image
from pycocotools.coco import COCO
from tqdm import tqdm

from embedder import ImageEmbedder, DEVICE
from text_embedder import TextEmbedder


def parse_args():
    p = argparse.ArgumentParser(description="COCO retrieval benchmark")
    p.add_argument("--coco-annotation-json", required=True,
                   help="path to captions_val2017.json")
    p.add_argument("--coco-image-dir", required=True,
                   help="path to val2017/ folder")
    p.add_argument("--llava-ckpt", required=True,
                   help="path to pytorch_model-00003-of-00003.bin")
    p.add_argument("--proj-model", default="openai/clip-vit-large-patch14-336")
    p.add_argument("--proj-dim", type=int, default=5120)
    p.add_argument("--text-model", required=True)
    p.add_argument("--text-tokenizer", required=True)
    p.add_argument("--batch-size", type=int, default=64)
    return p.parse_args()


def load_coco(ann_file):
    """Load COCO annotations and map each image id to its list of captions."""
    coco = COCO(ann_file)
    img_ids = coco.getImgIds()
    img2caps = {
        iid: [ann["caption"] for ann in coco.loadAnns(coco.getAnnIds(imgIds=iid))]
        for iid in img_ids
    }
    return coco, img_ids, img2caps


def compute_image_embeddings(embedder, coco, img_ids, img_dir, bs):
    """Embed all COCO images in batches and return a (num_images, dim) array."""
    all_embs = []
    for i in tqdm(range(0, len(img_ids), bs), desc="Images"):
        batch = img_ids[i:i + bs]
        imgs = []
        for iid in batch:
            fn = coco.loadImgs(iid)[0]["file_name"]
            imgs.append(Image.open(os.path.join(img_dir, fn)).convert("RGB"))
        with torch.no_grad():
            embs = torch.stack([embedder.image_to_embedding(im) for im in imgs])
        all_embs.append(embs.cpu().numpy())
    return np.vstack(all_embs)


def compute_text_embeddings(embedder, captions, bs):
    """Embed all captions in batches and return a (num_captions, dim) array."""
    all_embs = []
    for i in tqdm(range(0, len(captions), bs), desc="Texts"):
        batch = captions[i:i + bs]
        with torch.no_grad():
            embs = embedder.text_to_embedding(batch)
        all_embs.append(embs.cpu().numpy())
    return np.vstack(all_embs)


def compute_metrics(ranks, gt, K=(1, 5, 10)):
    """Compute Recall@K and median rank.

    `ranks[idx]` holds candidate indices sorted by descending similarity for
    query `idx`; `gt[idx]` is the list of correct candidate indices.
    """
    R, meds = {k: 0 for k in K}, []
    for idx, rank_list in enumerate(ranks):
        targets = gt[idx]
        # 0-based rank of the best-placed ground-truth candidate.
        best = min(int(np.where(rank_list == t)[0][0]) for t in targets)
        meds.append(best + 1)
        for k in K:
            if best < k:
                R[k] += 1
    n = len(ranks)
    recall = {k: R[k] / n for k in K}
    return recall, np.median(meds)


def evaluate(img_embs, txt_embs, img2caps, img_ids):
    sims = img_embs @ txt_embs.T

    # Image→Text: each image's ground truth is the index range of its captions.
    img2cap = {}
    offset = 0
    for i, iid in enumerate(img_ids):
        cnt = len(img2caps[iid])
        img2cap[i] = list(range(offset, offset + cnt))
        offset += cnt
    ranks_i2t = np.argsort(-sims, axis=1)
    R_i2t, med_i2t = compute_metrics(ranks_i2t, img2cap)
    print("Image→Text R@1,5,10:", [f"{R_i2t[k] * 100:.2f}%" for k in (1, 5, 10)])
    print("Image→Text Median Rank:", med_i2t)

    # Text→Image: each caption's ground truth is its source image.
    cap2img, offset = {}, 0
    for i, iid in enumerate(img_ids):
        for _ in img2caps[iid]:
            cap2img[offset] = i
            offset += 1
    ranks_t2i = np.argsort(-sims.T, axis=1)
    gt = {idx: [cap2img[idx]] for idx in range(len(ranks_t2i))}
    R_t2i, med_t2i = compute_metrics(ranks_t2i, gt)
    print("Text→Image R@1,5,10:", [f"{R_t2i[k] * 100:.2f}%" for k in (1, 5, 10)])
    print("Text→Image Median Rank:", med_t2i)


def main():
    args = parse_args()

    # 1) Load COCO
    coco, img_ids, img2caps = load_coco(args.coco_annotation_json)
    captions = [c for iid in img_ids for c in img2caps[iid]]

    # 2) Build embedders
    img_embedder = ImageEmbedder(
        vision_model_name=args.proj_model,
        proj_out_dim=args.proj_dim,
        llava_ckpt_path=args.llava_ckpt,
    )
    txt_embedder = TextEmbedder(
        model_name=args.text_model,
        tokenizer_name=args.text_tokenizer,
    )

    # 3) Compute embeddings
    img_vectors = compute_image_embeddings(
        img_embedder, coco, img_ids, args.coco_image_dir, args.batch_size
    )
    txt_vectors = compute_text_embeddings(
        txt_embedder, captions, args.batch_size
    )

    # 4) Evaluate retrieval
    evaluate(img_vectors, txt_vectors, img2caps, img_ids)


if __name__ == "__main__":
    main()
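
# Example invocation (a sketch only; the script filename and every path below
# are placeholders, not part of this repo: substitute your local COCO val2017
# annotations/images, LLaVA checkpoint, and text encoder/tokenizer):
#
#   python coco_retrieval.py \
#       --coco-annotation-json annotations/captions_val2017.json \
#       --coco-image-dir val2017 \
#       --llava-ckpt pytorch_model-00003-of-00003.bin \
#       --text-model <text-encoder name or path> \
#       --text-tokenizer <tokenizer name or path> \
#       --batch-size 64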