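"""
Embed images with a frozen CLIP vision encoder and LLaVA's pretrained
2-layer MLP projector, using patch-level features only.

Embeddings are L2-normalized 4096-d vectors, stored as a pickled
{image path: numpy array} mapping that is written atomically after each batch.

Command-line usage (paths are illustrative):

    python image_embedder.py --image path/to/photo.jpg [--no-save]
    python image_embedder.py --folder path/to/images --batch-size 32 --out embeddings/image_embeddings.pkl
"""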
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, UnidentifiedImageError
from transformers import CLIPImageProcessor, CLIPVisionModel
from tqdm import tqdm
from atomicwrites import atomic_write
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
PROJECTOR_WEIGHTS_PATH = (
Path(__file__).parent.parent
/ "models"
/ "llava-v1.5-mlp2x-336px-vicuna-7b"
/ "mm_projector.bin"
)
OUTPUT_DIM = 4096
DEFAULT_BATCH_SIZE = 32
# Default output path
EMBED_DIR = Path(__file__).parent.parent / "embeddings"
IMAGE_EMB_PATH = EMBED_DIR / "image_embeddings.pkl"
class ImageEmbedder:
"""
Frozen CLIP vision encoder + pretrained 2-layer MLP projector.
Uses PATCH-LEVEL features only (mean of patch embeddings) following
LLaVA's configuration: mm_vision_select_feature="patch".
Provides single-image and folder-wide embedding, with options to
print only or append to a persistent pickle store via atomic writes.
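
    Example usage (a minimal sketch; paths are illustrative):

        embedder = ImageEmbedder()
        vec = embedder.embed_image("photo.jpg", no_save=True)   # torch.Tensor, shape (4096,)
        mapping = embedder.embed_folder("images/")               # {path str: (4096,) np.ndarray}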
"""
def __init__(
self,
vision_model_name: str = DEFAULT_VISION_MODEL,
projector_weights_path: Union[str, Path] = PROJECTOR_WEIGHTS_PATH,
output_dim: int = OUTPUT_DIM,
):
# 1) Setup CLIP processor + model
self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
self.vision_model = (
CLIPVisionModel
.from_pretrained(vision_model_name, output_hidden_states=True)
.to(DEVICE)
.eval()
)
for p in self.vision_model.parameters():
p.requires_grad = False
# 2) Build MLP projector - PATCH ONLY (1024 input, not 2048)
# Following LLaVA config: mm_vision_select_feature="patch", mm_hidden_size=1024
hidden_dim = self.vision_model.config.hidden_size # 1024, not 1024*2
self.projector = nn.Sequential(
nn.Linear(hidden_dim, output_dim), # 1024 -> 4096
nn.GELU(),
nn.Linear(output_dim, output_dim), # 4096 -> 4096
).to(DEVICE)
# 3) Load projector weights
ckpt = torch.load(projector_weights_path, map_location=DEVICE)
state = ckpt.get("projector_state_dict", ckpt)
# Strip prefix
clean_state = {
(k[len("model.mm_projector."):] if k.startswith("model.mm_projector.") else k): v
for k, v in state.items()
}
# Load with warning on mismatch
missing, unexpected = self.projector.load_state_dict(clean_state, strict=False)
if missing or unexpected:
logger.info(f"Loaded projector weights with missing={missing}, unexpected={unexpected}")
def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
"""
Embed a single PIL.Image → (output_dim,) L2-normalized vector.
Uses PATCH-LEVEL features only (mean of patches, no CLS token).
"""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
except Exception as e:
logger.exception(f"Preprocessing failed: {e}")
raise
with torch.no_grad():
last_hidden = self.vision_model(**inputs).last_hidden_state
# PATCH-ONLY: Use mean of patch embeddings, drop CLS token
# last_hidden shape: (1, 1+num_patches, 1024)
# [:, 0, :] = CLS token (skip)
# [:, 1:, :] = patch tokens (use mean)
patch_mean = last_hidden[:, 1:, :].mean(dim=1) # (1, 1024)
combo = nn.functional.normalize(patch_mean, p=2, dim=1) # L2 normalize
proj = self.projector(combo) # (1, 4096)
emb = nn.functional.normalize(proj, p=2, dim=1) # L2 normalize output
return emb.squeeze(0) # → (OUTPUT_DIM,)
def embed_folder(
self,
folder_path: Union[str, Path],
save_path: Union[str, Path] = IMAGE_EMB_PATH,
        extensions: Optional[List[str]] = None,
batch_size: int = DEFAULT_BATCH_SIZE,
) -> Dict[str, np.ndarray]:
"""
Embed all images in a folder (recursively), skipping already-embedded
and corrupted files, saving after each batch.
Uses PATCH-LEVEL features only.
"""
if extensions is None:
extensions = [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".gif"]
folder = Path(folder_path)
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
# Load existing embeddings
try:
with open(save_path, "rb") as fp:
mapping = pickle.load(fp)
except FileNotFoundError:
mapping = {}
except Exception as e:
logger.warning(f"Could not load embeddings: {e}")
mapping = {}
# Collect files to embed
all_files = [f for f in folder.rglob("*") if f.suffix.lower() in extensions]
to_process = [f for f in all_files if str(f) not in mapping]
if not to_process:
logger.info("No new images to embed.")
return mapping
logger.info(f"Embedding {len(to_process)} images (batch size {batch_size})")
for i in tqdm(range(0, len(to_process), batch_size), desc="Images"):
batch = to_process[i: i + batch_size]
imgs, paths = [], []
for f in batch:
try:
imgs.append(Image.open(f).convert("RGB"))
paths.append(f)
except UnidentifiedImageError:
logger.warning(f"Skipping corrupted: {f}")
except Exception as e:
logger.warning(f"Error loading {f}: {e}")
if not imgs:
continue
# Embed batch
inputs = self.processor(images=imgs, return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
with torch.no_grad():
last_hidden = self.vision_model(**inputs).last_hidden_state
# PATCH-ONLY: Use mean of patches, drop CLS token
patch_means = last_hidden[:, 1:, :].mean(dim=1) # (batch_size, 1024)
combo = nn.functional.normalize(patch_means, p=2, dim=1)
proj = self.projector(combo) # (batch_size, 4096)
embs = nn.functional.normalize(proj, p=2, dim=1).cpu().detach().numpy()
# Update and save
for pth, emb in zip(paths, embs):
mapping[str(pth)] = emb
with atomic_write(save_path, mode="wb", overwrite=True) as fp:
pickle.dump(mapping, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Done. Total embeddings: {len(mapping)}")
return mapping
def embed_image(
self,
image_path: Union[str, Path],
save_path: Union[str, Path] = IMAGE_EMB_PATH,
no_save: bool = False,
) -> torch.Tensor:
"""
Embed one image using PATCH-LEVEL features only:
- prints the vector
- unless no_save=True, also appends to the pickle at save_path
Returns the torch.Tensor.
"""
img_path = Path(image_path)
try:
img = Image.open(img_path).convert("RGB")
vec = self.image_to_embedding(img)
except Exception as e:
logger.error(f"Failed embedding {img_path}: {e}")
raise
# Print out the vector
vec_np = vec.cpu().detach().numpy()
print(f"{img_path} → shape={vec_np.shape}")
print(vec_np)
# Append to pickle unless disabled
if not no_save:
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(save_path, "rb") as fp:
mapping = pickle.load(fp)
except (FileNotFoundError, pickle.PickleError):
mapping = {}
mapping[str(img_path)] = vec_np
            with atomic_write(save_path, mode="wb", overwrite=True) as fp:
                pickle.dump(mapping, fp, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Saved embedding for {img_path} to {save_path}")
return vec
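# A minimal consumer-side sketch (not part of this module), assuming the default
# pickle path: because every stored vector is L2-normalized, a plain dot product
# against a query embedding gives cosine similarity.
#
#     import pickle
#     import numpy as np
#
#     with open(IMAGE_EMB_PATH, "rb") as fp:
#         mapping = pickle.load(fp)                   # {path str: (4096,) float32 array}
#     paths = list(mapping)
#     matrix = np.stack([mapping[p] for p in paths])  # (N, 4096)
#     query = ImageEmbedder().embed_image("query.jpg", no_save=True).cpu().numpy()
#     scores = matrix @ query                         # cosine similarities
#     best_match = paths[int(np.argmax(scores))]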
def main():
import argparse
    parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--image", type=str, help="Path to a single image to embed")
parser.add_argument("--folder", type=str, help="Path to a folder of images")
parser.add_argument("--out", type=str, default=str(IMAGE_EMB_PATH),
help="Pickle to save embeddings (default: embeddings/image_embeddings.pkl)")
parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
help="Batch size for folder embedding")
parser.add_argument("--no-save", action="store_true",
help="When embedding a single image, do not save to the pickle")
args = parser.parse_args()
embedder = ImageEmbedder()
if args.image:
embedder.embed_image(
image_path=args.image,
save_path=args.out,
no_save=args.no_save
)
elif args.folder:
embedder.embed_folder(
folder_path=args.folder,
save_path=args.out,
batch_size=args.batch_size
)
else:
parser.error("Please specify --image or --folder")
if __name__ == "__main__":
main()