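"""Image embedding with a frozen CLIP vision encoder and LLaVA's MLP projector.

Embeds single images or whole folders into 4096-dim, L2-normalized vectors
(patch-level CLIP features passed through the pretrained mm_projector) and
appends them to a pickle store via atomic writes.

Example CLI usage (script name and paths are illustrative):

    python image_embedder.py --image photos/cat.jpg
    python image_embedder.py --folder photos/ --out embeddings/image_embeddings.pkl --batch-size 16
"""
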
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, UnidentifiedImageError
from transformers import CLIPImageProcessor, CLIPVisionModel
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
PROJECTOR_WEIGHTS_PATH = (
    Path(__file__).parent.parent
    / "models"
    / "llava-v1.5-mlp2x-336px-vicuna-7b"
    / "mm_projector.bin"
)
OUTPUT_DIM = 4096
DEFAULT_BATCH_SIZE = 32

# Default output path
EMBED_DIR = Path(__file__).parent.parent / "embeddings"
IMAGE_EMB_PATH = EMBED_DIR / "image_embeddings.pkl"


class ImageEmbedder:
    """
    Frozen CLIP vision encoder + pretrained 2-layer MLP projector.
    Uses PATCH-LEVEL features only (mean of patch embeddings) following
    LLaVA's configuration: mm_vision_select_feature="patch".
    Provides single-image and folder-wide embedding, with options to
    print only or append to a persistent pickle store via atomic writes.
    """

    def __init__(
        self,
        vision_model_name: str = DEFAULT_VISION_MODEL,
        projector_weights_path: Union[str, Path] = PROJECTOR_WEIGHTS_PATH,
        output_dim: int = OUTPUT_DIM,
    ):
        # 1) Set up CLIP processor + frozen vision model
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name, output_hidden_states=True)
            .to(DEVICE)
            .eval()
        )
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # 2) Build MLP projector - PATCH ONLY (1024-dim input, not 2048)
        # Following LLaVA config: mm_vision_select_feature="patch", mm_hidden_size=1024
        hidden_dim = self.vision_model.config.hidden_size  # 1024, not 1024*2
        self.projector = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),   # 1024 -> 4096
            nn.GELU(),
            nn.Linear(output_dim, output_dim),   # 4096 -> 4096
        ).to(DEVICE)

        # 3) Load projector weights
        ckpt = torch.load(projector_weights_path, map_location=DEVICE)
        state = ckpt.get("projector_state_dict", ckpt)
        # Strip the "model.mm_projector." prefix used by LLaVA checkpoints
        clean_state = {
            (k[len("model.mm_projector."):] if k.startswith("model.mm_projector.") else k): v
            for k, v in state.items()
        }
        # Load with warning on mismatch
        missing, unexpected = self.projector.load_state_dict(clean_state, strict=False)
        if missing or unexpected:
            logger.warning(f"Loaded projector weights with missing={missing}, unexpected={unexpected}")

    def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
        """
        Embed a single PIL.Image → (output_dim,) L2-normalized vector.
        Uses PATCH-LEVEL features only (mean of patches, no CLS token).
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        except Exception as e:
            logger.exception(f"Preprocessing failed: {e}")
            raise

        with torch.no_grad():
            last_hidden = self.vision_model(**inputs).last_hidden_state

        # PATCH-ONLY: use the mean of patch embeddings, drop the CLS token
        # last_hidden shape: (1, 1+num_patches, 1024)
        #   [:, 0, :]  = CLS token (skip)
        #   [:, 1:, :] = patch tokens (use mean)
        patch_mean = last_hidden[:, 1:, :].mean(dim=1)            # (1, 1024)
        combo = nn.functional.normalize(patch_mean, p=2, dim=1)   # L2 normalize

        proj = self.projector(combo)                              # (1, 4096)
        emb = nn.functional.normalize(proj, p=2, dim=1)           # L2 normalize output
        return emb.squeeze(0)  # → (OUTPUT_DIM,)

    def embed_folder(
        self,
        folder_path: Union[str, Path],
        save_path: Union[str, Path] = IMAGE_EMB_PATH,
        extensions: Optional[List[str]] = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ) -> Dict[str, np.ndarray]:
        """
        Embed all images in a folder (recursively), skipping already-embedded
        and corrupted files, saving after each batch.
        Uses PATCH-LEVEL features only.
        """
        if extensions is None:
            extensions = [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".gif"]

        folder = Path(folder_path)
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Load existing embeddings
        try:
            with open(save_path, "rb") as fp:
                mapping = pickle.load(fp)
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Could not load embeddings: {e}")
            mapping = {}

        # Collect files to embed
        all_files = [f for f in folder.rglob("*") if f.suffix.lower() in extensions]
        to_process = [f for f in all_files if str(f) not in mapping]
        if not to_process:
            logger.info("No new images to embed.")
            return mapping

        logger.info(f"Embedding {len(to_process)} images (batch size {batch_size})")
        for i in tqdm(range(0, len(to_process), batch_size), desc="Images"):
            batch = to_process[i: i + batch_size]
            imgs, paths = [], []
            for f in batch:
                try:
                    imgs.append(Image.open(f).convert("RGB"))
                    paths.append(f)
                except UnidentifiedImageError:
                    logger.warning(f"Skipping corrupted: {f}")
                except Exception as e:
                    logger.warning(f"Error loading {f}: {e}")

            if not imgs:
                continue

            # Embed batch
            inputs = self.processor(images=imgs, return_tensors="pt")
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            with torch.no_grad():
                last_hidden = self.vision_model(**inputs).last_hidden_state

            # PATCH-ONLY: use the mean of patches, drop the CLS token
            patch_means = last_hidden[:, 1:, :].mean(dim=1)  # (batch_size, 1024)
            combo = nn.functional.normalize(patch_means, p=2, dim=1)
            proj = self.projector(combo)  # (batch_size, 4096)
            embs = nn.functional.normalize(proj, p=2, dim=1).cpu().detach().numpy()

            # Update mapping and save atomically after each batch
            for pth, emb in zip(paths, embs):
                mapping[str(pth)] = emb
            with atomic_write(save_path, mode="wb", overwrite=True) as fp:
                pickle.dump(mapping, fp, protocol=pickle.HIGHEST_PROTOCOL)

        logger.info(f"Done. Total embeddings: {len(mapping)}")
        return mapping

    def embed_image(
        self,
        image_path: Union[str, Path],
        save_path: Union[str, Path] = IMAGE_EMB_PATH,
        no_save: bool = False,
    ) -> torch.Tensor:
        """
        Embed one image using PATCH-LEVEL features only:
        - prints the vector
        - unless no_save=True, also appends to the pickle at save_path
        Returns the torch.Tensor.
        """
        img_path = Path(image_path)
        try:
            img = Image.open(img_path).convert("RGB")
            vec = self.image_to_embedding(img)
        except Exception as e:
            logger.error(f"Failed embedding {img_path}: {e}")
            raise

        # Print out the vector
        vec_np = vec.cpu().detach().numpy()
        print(f"{img_path} → shape={vec_np.shape}")
        print(vec_np)

        # Append to pickle unless disabled
        if not no_save:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            try:
                with open(save_path, "rb") as fp:
                    mapping = pickle.load(fp)
            except (FileNotFoundError, pickle.PickleError):
                mapping = {}

            mapping[str(img_path)] = vec_np
            # Write in binary mode; atomic_write defaults to text mode otherwise
            with atomic_write(save_path, mode="wb", overwrite=True) as fp:
                pickle.dump(mapping, fp, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info(f"Saved embedding for {img_path} to {save_path}")

        return vec
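

# Example programmatic usage (a sketch; paths are illustrative and the LLaVA
# mm_projector.bin must exist at PROJECTOR_WEIGHTS_PATH):
#
#   embedder = ImageEmbedder()
#   vec = embedder.embed_image("photos/cat.jpg", no_save=True)   # embed + print one image
#   mapping = embedder.embed_folder("photos/", batch_size=16)    # recurse over a folder
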
def main():
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--image", type=str, help="Path to a single image to embed")
    parser.add_argument("--folder", type=str, help="Path to a folder of images")
    parser.add_argument("--out", type=str, default=str(IMAGE_EMB_PATH),
                        help="Pickle to save embeddings (default: embeddings/image_embeddings.pkl)")
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
                        help="Batch size for folder embedding")
    parser.add_argument("--no-save", action="store_true",
                        help="When embedding a single image, do not save to the pickle")
    args = parser.parse_args()

    embedder = ImageEmbedder()

    if args.image:
        embedder.embed_image(
            image_path=args.image,
            save_path=args.out,
            no_save=args.no_save,
        )
    elif args.folder:
        embedder.embed_folder(
            folder_path=args.folder,
            save_path=args.out,
            batch_size=args.batch_size,
        )
    else:
        parser.error("Please specify --image or --folder")


if __name__ == "__main__":
    main()