import logging
from typing import Optional

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use CUDA if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ImageEmbedder:
    """
    Image embedder built on a frozen CLIP vision model followed by a single
    linear projection. The layout stays compatible with LLaVA's mm_projector
    weights while allowing a different output dimension (default 4096 instead
    of LLaVA's 5120).
    """

    def __init__(self,
                 vision_model_name: str = "openai/clip-vit-large-patch14-336",
                 proj_out_dim: int = 4096,
                 llava_ckpt_path: Optional[str] = None):
        logger.info(f"Initializing ImageEmbedder with {vision_model_name}")

        # Load the CLIP image processor and vision encoder
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name)
            .to(DEVICE)
            .eval()
        )

        # Freeze the vision encoder
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # Vision model hidden dimension (1024 for CLIP ViT-L/14-336)
        vision_hidden_dim = self.vision_model.config.hidden_size

        # Single linear projection matching LLaVA's projector layout,
        # but with a configurable output dimension (4096 by default instead of 5120)
        self.projection = nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()

        # Optionally initialize the projection from LLaVA's mm_projector weights
        if llava_ckpt_path is not None:
            self._load_llava_weights(llava_ckpt_path, vision_hidden_dim, proj_out_dim)

        logger.info(f"ImageEmbedder ready: {vision_hidden_dim}D -> {proj_out_dim}D on {DEVICE}")

    def _load_llava_weights(self, ckpt_path, vision_dim, output_dim):
        """Load LLaVA projection weights, adapting them to the configured output dimension."""
        try:
            ckpt = torch.load(ckpt_path, map_location=DEVICE)

            # LLaVA stores the projector as a single linear layer under these keys
            keys = ["model.mm_projector.weight", "model.mm_projector.bias"]

            if all(k in ckpt for k in keys):
                original_weight = ckpt["model.mm_projector.weight"]  # shape: (5120, 1024)
                original_bias = ckpt["model.mm_projector.bias"]      # shape: (5120,)

                if output_dim == 5120:
                    # Same output dimension: load the weights directly
                    state = {
                        "weight": original_weight,
                        "bias": original_bias,
                    }
                    self.projection.load_state_dict(state)
                    logger.info("Loaded original LLaVA projection weights")

                elif output_dim < 5120:
                    # Smaller output dimension: keep only the first output_dim rows.
                    # This is a plain truncation, not an importance-based selection.
                    truncated_weight = original_weight[:output_dim, :]  # (output_dim, 1024)
                    truncated_bias = original_bias[:output_dim]         # (output_dim,)

                    state = {
                        "weight": truncated_weight,
                        "bias": truncated_bias,
                    }
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded truncated LLaVA weights: 5120D -> {output_dim}D")

                else:
                    # Larger output dimension: copy the original rows and
                    # Xavier-initialize the extra ones
                    new_weight = torch.zeros(output_dim, vision_dim, device=DEVICE)
                    new_bias = torch.zeros(output_dim, device=DEVICE)

                    # Copy original weights
                    new_weight[:5120, :] = original_weight
                    new_bias[:5120] = original_bias

                    # Initialize the additional rows
                    nn.init.xavier_uniform_(new_weight[5120:, :])
                    nn.init.zeros_(new_bias[5120:])

                    state = {
                        "weight": new_weight,
                        "bias": new_bias,
                    }
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded extended LLaVA weights: 5120D -> {output_dim}D")
            else:
                logger.warning("LLaVA projector weights not found, using random initialization")

        except Exception as e:
            logger.warning(f"Failed to load LLaVA weights: {e}, using random initialization")

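    # Note: the key names above ("model.mm_projector.weight"/".bias") match the
    # original single-linear LLaVA projector. Later LLaVA releases ship a multi-layer
    # projector stored under different keys, in which case the fallback warning above
    # will fire. A quick inspection like the sketch below (checkpoint path is a
    # hypothetical placeholder) can confirm what a given checkpoint actually contains:
    #
    #   ckpt = torch.load("/path/to/llava_checkpoint.bin", map_location="cpu")
    #   print(sorted(k for k in ckpt if "mm_projector" in k))
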
    def image_to_embedding(self, image_input) -> torch.Tensor:
        """
        Convert one image or a batch of images into embeddings.

        Args:
            image_input: PIL.Image or list of PIL.Images

        Returns:
            torch.Tensor: embedding of shape (proj_out_dim,) for a single image,
            or (B, proj_out_dim) for a batch of images
        """
        # Normalize the input to a list and remember whether it was a single image
        if isinstance(image_input, Image.Image):
            images = [image_input]
            single_image = True
        else:
            images = image_input
            single_image = False

        # Preprocess images with the CLIP processor
        inputs = self.processor(images=images, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (B, 3, H, W)

        # Encode with the frozen vision model and project to the output dimension
        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            feat = out.pooler_output     # (B, 1024)
            emb = self.projection(feat)  # (B, proj_out_dim)

        # Return a 1-D tensor for a single image, a 2-D batch otherwise
        if single_image:
            return emb.squeeze(0)  # (proj_out_dim,)
        else:
            return emb             # (B, proj_out_dim)

    def get_memory_usage(self):
        """Get current GPU memory usage if available"""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,    # GB
            }
        return {'allocated': 0, 'reserved': 0}


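# Illustrative sketch (not used elsewhere in this module): how the embedder can be
# applied to a small batch and a cosine-similarity comparison. The image paths are
# hypothetical placeholders.
def _example_similarity(embedder: ImageEmbedder,
                        path_a: str = "example_a.jpg",
                        path_b: str = "example_b.jpg") -> float:
    """Embed two images as a batch and return their cosine similarity."""
    images = [Image.open(path_a).convert("RGB"), Image.open(path_b).convert("RGB")]
    emb = embedder.image_to_embedding(images)  # (2, proj_out_dim)
    sim = nn.functional.cosine_similarity(emb[0], emb[1], dim=0)
    return sim.item()

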
# Test function
def main():
    """Simple CLI test of the embedder"""
    import argparse

    parser = argparse.ArgumentParser(description="Test ImageEmbedder")
    parser.add_argument("--image", type=str, help="Path to test image")
    parser.add_argument("--model", type=str, default="openai/clip-vit-large-patch14-336")
    parser.add_argument("--dim", type=int, default=4096)
    parser.add_argument("--llava-ckpt", type=str, help="Path to LLaVA checkpoint")
    args = parser.parse_args()

    embedder = ImageEmbedder(
        vision_model_name=args.model,
        proj_out_dim=args.dim,
        llava_ckpt_path=args.llava_ckpt,
    )

    if args.image:
        img = Image.open(args.image).convert("RGB")
        emb = embedder.image_to_embedding(img)
        print(f"Embedding shape: {emb.shape}")
        print(f"Sample values: {emb[:10]}")

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")


if __name__ == "__main__":
    main()
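# Example invocation (module filename and paths are illustrative assumptions):
#   python image_embedder.py --image sample.jpg --dim 4096 \
#       --llava-ckpt /path/to/llava_checkpoint.bin
# With --dim 4096 this prints an embedding of shape torch.Size([4096]).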