# Img2Vec/benchmark_coco.py
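"""COCO caption-retrieval benchmark for the Img2Vec image and text embedders.

Embeds every image and caption in the COCO val2017 split, scores all
image-caption pairs, and reports Recall@1/5/10 and median rank for both
Image→Text and Text→Image retrieval.

Example invocation (all paths below are illustrative placeholders):

    python benchmark_coco.py \
        --coco-annotation-json annotations/captions_val2017.json \
        --coco-image-dir val2017/ \
        --llava-ckpt checkpoints/pytorch_model-00003-of-00003.bin \
        --text-model <text-model-name-or-path> \
        --text-tokenizer <tokenizer-name-or-path>
"""
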
import os
import argparse
import numpy as np
from tqdm import tqdm
from pycocotools.coco import COCO
from PIL import Image
import torch
from embedder import ImageEmbedder, DEVICE
from text_embedder import TextEmbedder


def parse_args():
    p = argparse.ArgumentParser(description="COCO retrieval benchmark")
    p.add_argument("--coco-annotation-json", required=True,
                   help="path to captions_val2017.json")
    p.add_argument("--coco-image-dir", required=True,
                   help="path to val2017/ folder")
    p.add_argument("--llava-ckpt", required=True,
                   help="path to pytorch_model-00003-of-00003.bin")
    p.add_argument("--proj-model", default="openai/clip-vit-large-patch14-336")
    p.add_argument("--proj-dim", type=int, default=5120)
    p.add_argument("--text-model", required=True)
    p.add_argument("--text-tokenizer", required=True)
    p.add_argument("--batch-size", type=int, default=64)
    return p.parse_args()


def load_coco(ann_file):
    """Load COCO annotations and map each image id to its caption strings."""
    coco = COCO(ann_file)
    img_ids = coco.getImgIds()
    img2caps = {iid: [ann["caption"] for ann in coco.loadAnns(coco.getAnnIds(imgIds=iid))]
                for iid in img_ids}
    return coco, img_ids, img2caps


def compute_image_embeddings(embedder, coco, img_ids, img_dir, bs):
    """Embed every COCO image in batches and return an (N, D) array."""
    all_embs = []
    for i in tqdm(range(0, len(img_ids), bs), desc="Images"):
        batch = img_ids[i:i + bs]
        imgs = []
        for iid in batch:
            fn = coco.loadImgs(iid)[0]["file_name"]
            imgs.append(Image.open(os.path.join(img_dir, fn)).convert("RGB"))
        with torch.no_grad():
            embs = torch.stack([embedder.image_to_embedding(im) for im in imgs])
        all_embs.append(embs.cpu().numpy())
    return np.vstack(all_embs)


def compute_text_embeddings(embedder, captions, bs):
    """Embed all captions in batches and return an (M, D) array."""
    all_embs = []
    for i in tqdm(range(0, len(captions), bs), desc="Texts"):
        batch = captions[i:i + bs]
        with torch.no_grad():
            embs = embedder.text_to_embedding(batch)
        all_embs.append(embs.cpu().numpy())
    return np.vstack(all_embs)


def compute_metrics(ranks, gt, K=(1, 5, 10)):
    """Compute Recall@K and median rank.

    `ranks[i]` is the candidate ordering for query i (best match first) and
    `gt[i]` is the list of correct candidate indices for that query.
    """
    R, meds = {k: 0 for k in K}, []
    for idx, rank_list in enumerate(ranks):
        targets = gt[idx]
        best = min(int(np.where(rank_list == t)[0][0]) for t in targets)
        meds.append(best + 1)
        for k in K:
            if best < k:
                R[k] += 1
    n = len(ranks)
    recall = {k: R[k] / n for k in K}
    return recall, np.median(meds)
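
# Illustrative sanity check (not executed by the script): with
#   ranks = np.array([[0, 1, 2], [2, 0, 1]]) and gt = {0: [0], 1: [1]},
# the correct items sit at ranks 1 and 3, so compute_metrics returns
#   ({1: 0.5, 5: 1.0, 10: 1.0}, 2.0).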


def evaluate(img_embs, txt_embs, img2caps, img_ids):
    """Score all image-caption pairs and report retrieval metrics both ways."""
    # Dot-product similarity; this equals cosine similarity if the embedders
    # return L2-normalized vectors (otherwise normalize before this step).
    sims = img_embs @ txt_embs.T
    # Image→Text: map each image row to the column indices of its captions.
    img2cap = {}
    offset = 0
    for i, iid in enumerate(img_ids):
        cnt = len(img2caps[iid])
        img2cap[i] = list(range(offset, offset + cnt))
        offset += cnt
    ranks_i2t = np.argsort(-sims, axis=1)
    R_i2t, med_i2t = compute_metrics(ranks_i2t, img2cap)
    print("Image→Text R@1,5,10:", [f"{R_i2t[k] * 100:.2f}%" for k in (1, 5, 10)])
    print("Image→Text Median Rank:", med_i2t)
    # Text→Image: map each caption row to the index of its source image.
    cap2img, offset = {}, 0
    for i, iid in enumerate(img_ids):
        for _ in img2caps[iid]:
            cap2img[offset] = i
            offset += 1
    ranks_t2i = np.argsort(-sims.T, axis=1)
    gt = {idx: [cap2img[idx]] for idx in range(len(ranks_t2i))}
    R_t2i, med_t2i = compute_metrics(ranks_t2i, gt)
    print("Text→Image R@1,5,10:", [f"{R_t2i[k] * 100:.2f}%" for k in (1, 5, 10)])
    print("Text→Image Median Rank:", med_t2i)


def main():
    args = parse_args()
    # 1) Load COCO
    coco, img_ids, img2caps = load_coco(args.coco_annotation_json)
    captions = [c for iid in img_ids for c in img2caps[iid]]
    # 2) Build embedders
    img_embedder = ImageEmbedder(
        vision_model_name=args.proj_model,
        proj_out_dim=args.proj_dim,
        llava_ckpt_path=args.llava_ckpt
    )
    txt_embedder = TextEmbedder(
        model_name=args.text_model,
        tokenizer_name=args.text_tokenizer
    )
    # 3) Compute embeddings
    img_vectors = compute_image_embeddings(
        img_embedder, coco, img_ids, args.coco_image_dir, args.batch_size
    )
    txt_vectors = compute_text_embeddings(
        txt_embedder, captions, args.batch_size
    )
    # 4) Evaluate retrieval
    evaluate(img_vectors, txt_vectors, img2caps, img_ids)


if __name__ == "__main__":
    main()