Compare commits

..

3 Commits

Author   SHA1         Message                                      Date
ldy      2103e28aea   Rough implementation of 5120->4096 dim       2025-06-18 16:02:01 +08:00
ldy      2ba5de55ba   Merge remote-tracking branch 'origin/main'   2025-06-13 15:30:56 +08:00
ldy      a4f028c25d   Update find_similar_img.py                   2025-06-13 07:16:28 +00:00
5 changed files with 1440 additions and 106 deletions


@@ -1,17 +1,31 @@
import torch
import torch.nn as nn
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use CUDA if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageEmbedder:
    """
    Image embedder that outputs 4096-dimensional embeddings using CLIP vision model.
    Maintains compatibility with existing LLaVA weights while reducing dimension.
    """
    def __init__(self,
                 vision_model_name: str = "openai/clip-vit-large-patch14-336",
                 proj_out_dim: int = 4096,  # Changed from 5120 to 4096
                 llava_ckpt_path: str = None):
        logger.info(f"Initializing ImageEmbedder with {vision_model_name}")

        # Load CLIP vision encoder + processor (keep exactly the same)
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel

@@ -19,37 +33,157 @@ class ImageEmbedder:
            .to(DEVICE)
            .eval()
        )

        # Freeze vision model parameters
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # Get vision model's hidden dimension
        vision_hidden_dim = self.vision_model.config.hidden_size  # should be 1024

        # Keep the simple projection architecture that works!
        # Just change output dimension from 5120 to 4096
        self.projection = nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()

        # Load LLaVA's projection weights if provided
        if llava_ckpt_path is not None:
            self._load_llava_weights(llava_ckpt_path, vision_hidden_dim, proj_out_dim)

        logger.info(f"ImageEmbedder ready: {vision_hidden_dim}D -> {proj_out_dim}D on {DEVICE}")

    def _load_llava_weights(self, ckpt_path, vision_dim, output_dim):
        """Load LLaVA projection weights, adapting for different output dimension"""
        try:
            ckpt = torch.load(ckpt_path, map_location=DEVICE)

            # Extract projector weights (same as your original code)
            keys = ["model.mm_projector.weight", "model.mm_projector.bias"]
            if all(k in ckpt for k in keys):
                original_weight = ckpt["model.mm_projector.weight"]  # Shape: (5120, 1024)
                original_bias = ckpt["model.mm_projector.bias"]      # Shape: (5120,)

                if output_dim == 5120:
                    # Direct loading for original dimension
                    state = {
                        "weight": original_weight,
                        "bias": original_bias
                    }
                    self.projection.load_state_dict(state)
                    logger.info("Loaded original LLaVA projection weights")

                elif output_dim < 5120:
                    # Truncate weights for smaller dimension (preserves most important features)
                    truncated_weight = original_weight[:output_dim, :]  # (4096, 1024)
                    truncated_bias = original_bias[:output_dim]         # (4096,)
                    state = {
                        "weight": truncated_weight,
                        "bias": truncated_bias
                    }
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded truncated LLaVA weights: {5120}D -> {output_dim}D")

                else:
                    # For larger dimensions, pad with Xavier initialization
                    new_weight = torch.zeros(output_dim, vision_dim, device=DEVICE)
                    new_bias = torch.zeros(output_dim, device=DEVICE)

                    # Copy original weights
                    new_weight[:5120, :] = original_weight
                    new_bias[:5120] = original_bias

                    # Initialize additional weights
                    nn.init.xavier_uniform_(new_weight[5120:, :])
                    nn.init.zeros_(new_bias[5120:])

                    state = {
                        "weight": new_weight,
                        "bias": new_bias
                    }
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded extended LLaVA weights: {5120}D -> {output_dim}D")
            else:
                logger.warning("LLaVA projector weights not found, using random initialization")

        except Exception as e:
            logger.warning(f"Failed to load LLaVA weights: {e}, using random initialization")

    def image_to_embedding(self, image_input) -> torch.Tensor:
        """
        Convert image(s) to embeddings - keeping your exact original logic!

        Args:
            image_input: PIL.Image or list of PIL.Images

        Returns:
            torch.Tensor: Embeddings of shape (proj_out_dim,) for single image
                          or (B, proj_out_dim) for batch of images
        """
        # Handle single image vs batch (same as your original)
        if isinstance(image_input, Image.Image):
            images = [image_input]
            single_image = True
        else:
            images = image_input
            single_image = False

        # Preprocess images (same as your original)
        inputs = self.processor(images=images, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (B, 3, H, W)

        # Forward pass (exactly like your original)
        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            feat = out.pooler_output        # (B, 1024)
            emb = self.projection(feat)     # (B, proj_out_dim)

        # Return appropriate shape (same as your original)
        if single_image:
            return emb.squeeze(0)  # (proj_out_dim,)
        else:
            return emb             # (B, proj_out_dim)

    def get_memory_usage(self):
        """Get current GPU memory usage if available"""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,    # GB
            }
        return {'allocated': 0, 'reserved': 0}


# Test function
def main():
    """Simple test of the embedder"""
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(description="Test ImageEmbedder")
    parser.add_argument("--image", type=str, help="Path to test image")
    parser.add_argument("--model", type=str, default="openai/clip-vit-large-patch14-336")
    parser.add_argument("--dim", type=int, default=4096)
    parser.add_argument("--llava-ckpt", type=str, help="Path to LLaVA checkpoint")
    args = parser.parse_args()

    embedder = ImageEmbedder(
        vision_model_name=args.model,
        proj_out_dim=args.dim,
        llava_ckpt_path=args.llava_ckpt
    )

    if args.image:
        img = Image.open(args.image).convert("RGB")
        emb = embedder.image_to_embedding(img)
        print(f"Embedding shape: {emb.shape}")
        print(f"Sample values: {emb[:10]}")

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")


if __name__ == "__main__":
    main()
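
For reference, the key change in this commit is the elif output_dim < 5120 branch, which row-slices LLaVA's 5120x1024 mm_projector down to 4096 outputs. A minimal standalone sketch of that truncation, using random tensors as stand-ins for the real checkpoint:

import torch
import torch.nn as nn

# Stand-ins for the checkpoint tensors named above:
# model.mm_projector.weight has shape (5120, 1024), bias has shape (5120,)
full_weight = torch.randn(5120, 1024)
full_bias = torch.randn(5120)

proj = nn.Linear(1024, 4096)
proj.load_state_dict({
    "weight": full_weight[:4096, :],  # keep the first 4096 output rows
    "bias": full_bias[:4096],
})

feat = torch.randn(1, 1024)   # shaped like CLIP's pooler_output
print(proj(feat).shape)       # torch.Size([1, 4096])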


@@ -1,36 +1,444 @@
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from embedder import ImageEmbedder
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default paths
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
DEFAULT_TOP_K = 10
class ImageSimilaritySearcher:
"""
Image-to-image similarity search using unified embeddings dictionary
"""
def __init__(
self,
image_embeddings_path: Path,
vision_model_name: str = DEFAULT_VISION_MODEL,
embedding_dim: int = 4096,
llava_ckpt_path: Optional[str] = None
):
self.image_embeddings_path = Path(image_embeddings_path)
self.embedding_dim = embedding_dim
# Load image embeddings
logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
self.image_data = self._load_image_embeddings()
# Initialize image embedder for query images
logger.info(f"Initializing image embedder: {vision_model_name}")
self.image_embedder = ImageEmbedder(
vision_model_name=vision_model_name,
proj_out_dim=embedding_dim,
llava_ckpt_path=llava_ckpt_path
)
logger.info(f"ImageSimilaritySearcher ready with {len(self.image_data)} images")
def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
"""Load image embeddings from unified pickle file"""
try:
with open(self.image_embeddings_path, "rb") as f:
data = pickle.load(f)
if not isinstance(data, dict):
raise ValueError("Image embeddings file should contain a dictionary")
# Verify embedding dimensions
if data:
sample_key = next(iter(data))
sample_emb = data[sample_key]
if sample_emb.shape[-1] != self.embedding_dim:
raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D")
logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
else:
logger.warning("Empty embeddings dictionary loaded")
return data
except FileNotFoundError:
raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}")
except Exception as e:
raise RuntimeError(f"Failed to load image embeddings: {e}")
def search_by_image_path(
self,
query_image_path: str,
top_k: int = DEFAULT_TOP_K,
exclude_self: bool = True,
return_scores: bool = False
) -> List[Dict]:
"""
Search for similar images given a query image path
Args:
query_image_path: Path to query image
top_k: Number of top results to return
exclude_self: Whether to exclude the query image from results
return_scores: Whether to include similarity scores
Returns:
List of dictionaries with 'path' and optionally 'score' keys
"""
logger.info(f"Searching for images similar to: {query_image_path}")
# Load and embed query image
try:
query_img = Image.open(query_image_path).convert("RGB")
query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy()
except Exception as e:
raise RuntimeError(f"Failed to process query image {query_image_path}: {e}")
return self._search_by_embedding(
query_embedding,
query_image_path if exclude_self else None,
top_k,
return_scores
)
def search_by_image(
self,
query_image: Image.Image,
top_k: int = DEFAULT_TOP_K,
return_scores: bool = False
) -> List[Dict]:
"""
Search for similar images given a PIL Image
Args:
query_image: PIL Image object
top_k: Number of top results to return
return_scores: Whether to include similarity scores
Returns:
List of dictionaries with 'path' and optionally 'score' keys
"""
logger.info("Searching for similar images using provided image")
# Embed query image
query_embedding = self.image_embedder.image_to_embedding(query_image).cpu().numpy()
return self._search_by_embedding(query_embedding, None, top_k, return_scores)
def _search_by_embedding(
self,
query_embedding: np.ndarray,
exclude_path: Optional[str],
top_k: int,
return_scores: bool
) -> List[Dict]:
"""Internal method to search by embedding vector"""
# Prepare database embeddings
image_paths = list(self.image_data.keys())
image_embeddings = np.stack(list(self.image_data.values())) # (N, D)
# Compute cosine similarities
similarities = cosine_similarity(
query_embedding.reshape(1, -1),
image_embeddings
).flatten()
# Handle exclusion
valid_indices = list(range(len(image_paths)))
if exclude_path:
try:
exclude_idx = image_paths.index(exclude_path)
valid_indices.remove(exclude_idx)
similarities = np.delete(similarities, exclude_idx)
image_paths = [p for i, p in enumerate(image_paths) if i != exclude_idx]
except ValueError:
pass # Query path not in database
# Get top-k indices
top_k_indices = np.argsort(-similarities)[:top_k]
# Prepare results
results = []
for idx in top_k_indices:
result = {'path': image_paths[idx]}
if return_scores:
result['score'] = float(similarities[idx])
results.append(result)
return results
def batch_search(
self,
query_image_paths: List[str],
top_k: int = DEFAULT_TOP_K,
exclude_self: bool = True
) -> Dict[str, List[Dict]]:
"""
Search for multiple query images at once
Args:
query_image_paths: List of paths to query images
top_k: Number of results per query
exclude_self: Whether to exclude query images from results
Returns:
Dictionary mapping query path to list of results
"""
logger.info(f"Batch searching {len(query_image_paths)} images")
results = {}
for query_path in query_image_paths:
try:
results[query_path] = self.search_by_image_path(
query_path,
top_k=top_k,
exclude_self=exclude_self,
return_scores=True
)
except Exception as e:
logger.error(f"Failed to search for {query_path}: {e}")
results[query_path] = []
return results
def get_similarity_stats(self, query_image_path: str) -> Dict:
"""
Get statistics about similarity scores for a query image
Args:
query_image_path: Path to query image
Returns:
Dictionary with similarity statistics
"""
try:
query_img = Image.open(query_image_path).convert("RGB")
query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy()
except Exception as e:
raise RuntimeError(f"Failed to process query image: {e}")
image_embeddings = np.stack(list(self.image_data.values()))
similarities = cosine_similarity(
query_embedding.reshape(1, -1),
image_embeddings
).flatten()
return {
'mean_similarity': float(np.mean(similarities)),
'max_similarity': float(np.max(similarities)),
'min_similarity': float(np.min(similarities)),
'std_similarity': float(np.std(similarities)),
'median_similarity': float(np.median(similarities))
}
def get_memory_usage(self):
"""Get memory usage information"""
return self.image_embedder.get_memory_usage()
def print_results(results: List[Dict], query: str, show_scores: bool = True):
"""Pretty print search results"""
print(f"\nTop {len(results)} similar images to: {Path(query).name}")
print(f"Query: {query}")
print("-" * 100)
for i, result in enumerate(results, 1):
path = result['path']
if show_scores and 'score' in result:
score = result['score']
print(f"{i:2d}. {path:<80} score={score:>7.4f}")
else:
print(f"{i:2d}. {path}")
print("-" * 100)
def main():
parser = argparse.ArgumentParser(description="Image-to-Image Similarity Search")
# Input/Output paths
parser.add_argument(
"--image-embeddings",
type=str,
default=str(DEFAULT_IMG_EMB_PATH),
help="Path to image embeddings pickle file"
)
# Query options
parser.add_argument(
"--query-image",
type=str,
help="Path to query image"
)
parser.add_argument(
"--query-images-file",
type=str,
help="File with one image path per line"
)
parser.add_argument(
"--query-dir",
type=str,
help="Directory containing query images"
)
# Model parameters
parser.add_argument(
"--vision-model",
type=str,
default=DEFAULT_VISION_MODEL,
help="CLIP vision model to use"
)
parser.add_argument(
"--embedding-dim",
type=int,
default=4096,
help="Embedding dimension"
)
parser.add_argument(
"--llava-ckpt",
type=str,
help="Path to LLaVA checkpoint for projection weights"
)
# Search parameters
parser.add_argument(
"--top-k",
type=int,
default=DEFAULT_TOP_K,
help="Number of top results to return"
)
parser.add_argument(
"--include-self",
action="store_true",
help="Include query image in results"
)
# Output options
parser.add_argument(
"--output",
type=str,
help="Save results to JSON file"
)
parser.add_argument(
"--stats",
action="store_true",
help="Show similarity statistics"
)
args = parser.parse_args()
# Initialize searcher
try:
searcher = ImageSimilaritySearcher(
image_embeddings_path=args.image_embeddings,
vision_model_name=args.vision_model,
embedding_dim=args.embedding_dim,
llava_ckpt_path=args.llava_ckpt
)
except Exception as e:
logger.error(f"Failed to initialize searcher: {e}")
return
# Show memory usage
mem = searcher.get_memory_usage()
logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
# Handle different query modes
if args.query_image:
# Single query image
results = searcher.search_by_image_path(
args.query_image,
top_k=args.top_k,
exclude_self=not args.include_self,
return_scores=True
)
print_results(results, args.query_image)
if args.stats:
stats = searcher.get_similarity_stats(args.query_image)
print(f"\nSimilarity Statistics:")
for key, value in stats.items():
print(f" {key}: {value:.4f}")
elif args.query_images_file:
# Batch queries from file
try:
with open(args.query_images_file, 'r', encoding='utf-8') as f:
query_paths = [line.strip() for line in f if line.strip()]
batch_results = searcher.batch_search(
query_paths,
top_k=args.top_k,
exclude_self=not args.include_self
)
for query_path, results in batch_results.items():
if results: # Only show if we got results
print_results(results, query_path)
except Exception as e:
logger.error(f"Failed to process queries file: {e}")
return
elif args.query_dir:
# Query all images in directory
query_dir = Path(args.query_dir)
if not query_dir.exists():
logger.error(f"Query directory does not exist: {query_dir}")
return
# Find images in directory
image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff', '.tif'}
query_paths = []
for ext in image_exts:
query_paths.extend(query_dir.glob(f"*{ext}"))
query_paths.extend(query_dir.glob(f"*{ext.upper()}"))
query_paths = [str(p) for p in sorted(query_paths)]
if not query_paths:
logger.error(f"No images found in {query_dir}")
return
logger.info(f"Found {len(query_paths)} images in {query_dir}")
batch_results = searcher.batch_search(
query_paths,
top_k=args.top_k,
exclude_self=not args.include_self
)
for query_path, results in batch_results.items():
if results: # Only show if we got results
print_results(results, query_path)
else:
print("Please specify a query image, query images file, or query directory")
print("Use --help for more information")
return
# Save results if requested
if args.output and 'batch_results' in locals():
import json
try:
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(batch_results, f, indent=2, ensure_ascii=False)
logger.info(f"Results saved to {args.output}")
except Exception as e:
logger.error(f"Failed to save results: {e}")
if __name__ == "__main__":
main()
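
For reference, a self-contained sketch of the retrieval step that _search_by_embedding performs above (random vectors and toy file names stand in for real 4096-D embeddings; only numpy and scikit-learn are needed):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy database: path -> 4096-D vector, as stored in image_embeddings.pkl
rng = np.random.default_rng(0)
db = {f"img_{i}.jpg": rng.normal(size=4096) for i in range(5)}

paths = list(db.keys())
matrix = np.stack(list(db.values()))   # (N, D)
query = rng.normal(size=4096)          # would come from ImageEmbedder in practice

sims = cosine_similarity(query.reshape(1, -1), matrix).flatten()
for rank, idx in enumerate(np.argsort(-sims)[:3], 1):
    print(f"{rank}. {paths[idx]} score={sims[idx]:.4f}")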


@@ -1,101 +1,231 @@
import os
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict

import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from atomicwrites import atomic_write

from embedder import ImageEmbedder, DEVICE

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_OUTPUT_DIR = Path("embeddings")
DEFAULT_EMB_FILE = "image_embeddings.pkl"


def parse_args():
    parser = argparse.ArgumentParser(
        description="Batch-convert images into 4096D embeddings stored in unified dictionary"
    )
    parser.add_argument(
        "--image-dir", required=True,
        help="Path to folder containing images (jpg/png/bmp/gif/webp)"
    )
    parser.add_argument(
        "--output-dir", default=str(DEFAULT_OUTPUT_DIR),
        help="Directory to save embeddings file"
    )
    parser.add_argument(
        "--output-file", default=DEFAULT_EMB_FILE,
        help="Filename for embeddings dictionary (pkl format)"
    )
    parser.add_argument(
        "--vision-model", default="openai/clip-vit-large-patch14-336",
        help="Hugging Face CLIP vision model name"
    )
    parser.add_argument(
        "--embedding-dim", type=int, default=4096,
        help="Output embedding dimension"
    )
    parser.add_argument(
        "--llava-ckpt", default=None,
        help="(Optional) LLaVA checkpoint to load projection weights from"
    )
    parser.add_argument(
        "--batch-size", type=int, default=32,
        help="Batch size for processing images"
    )
    parser.add_argument(
        "--resume", action="store_true",
        help="Resume from existing embeddings file (skip already processed images)"
    )
    return parser.parse_args()


def find_images(folder: str) -> list:
    """Find all supported image files in folder (non-recursive by default)"""
    supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".tif")
    image_paths = []

    folder_path = Path(folder)

    # Use a set to avoid duplicates
    found_files = set()

    # Search for files with supported extensions (case-insensitive)
    for file_path in folder_path.iterdir():
        if file_path.is_file():
            file_ext = file_path.suffix.lower()
            if file_ext in supported_exts:
                found_files.add(str(file_path.resolve()))

    return sorted(list(found_files))


def load_existing_embeddings(filepath: Path) -> Dict[str, np.ndarray]:
    """Load existing embeddings dictionary"""
    if not filepath.exists():
        return {}
    try:
        with open(filepath, "rb") as f:
            embeddings = pickle.load(f)
        logger.info(f"Loaded {len(embeddings)} existing embeddings from {filepath}")
        return embeddings
    except Exception as e:
        logger.warning(f"Failed to load existing embeddings: {e}")
        return {}


def save_embeddings(embeddings: Dict[str, np.ndarray], filepath: Path):
    """Safely save embeddings dictionary"""
    filepath.parent.mkdir(parents=True, exist_ok=True)
    with atomic_write(filepath, mode='wb', overwrite=True) as f:
        pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
    logger.info(f"Saved {len(embeddings)} embeddings to {filepath}")


def main():
    args = parse_args()

    # Setup paths
    output_dir = Path(args.output_dir)
    output_file = output_dir / args.output_file

    # Discover images
    logger.info(f"Scanning for images in {args.image_dir}")
    image_paths = find_images(args.image_dir)

    if not image_paths:
        raise RuntimeError(f"No supported images found in {args.image_dir}")
    logger.info(f"Found {len(image_paths)} images")

    # Load existing embeddings if resuming
    embeddings_dict = {}
    if args.resume:
        embeddings_dict = load_existing_embeddings(output_file)
        # Filter out already processed images
        new_image_paths = [p for p in image_paths if p not in embeddings_dict]
        logger.info(f"Resuming: {len(embeddings_dict)} already processed, {len(new_image_paths)} remaining")
        image_paths = new_image_paths
        if not image_paths:
            logger.info("All images already processed!")
            return

    # Initialize embedder
    logger.info(f"Initializing ImageEmbedder on {DEVICE}")
    embedder = ImageEmbedder(
        vision_model_name=args.vision_model,
        proj_out_dim=args.embedding_dim,
        llava_ckpt_path=args.llava_ckpt
    )

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Process images in batches
    processed_count = 0
    failed_count = 0
    progress_bar = tqdm(total=len(image_paths), desc="Processing images")

    for batch_start in range(0, len(image_paths), args.batch_size):
        batch_end = min(batch_start + args.batch_size, len(image_paths))
        batch_paths = image_paths[batch_start:batch_end]

        # Load batch images
        batch_images = []
        valid_paths = []
        for img_path in batch_paths:
            try:
                img = Image.open(img_path).convert("RGB")
                batch_images.append(img)
                valid_paths.append(img_path)
            except Exception as e:
                logger.warning(f"Failed to load {img_path}: {e}")
                failed_count += 1
                progress_bar.update(1)
                continue

        if not batch_images:
            continue

        # Generate embeddings for batch
        try:
            with torch.no_grad():
                if len(batch_images) == 1:
                    embeddings = embedder.image_to_embedding(batch_images[0]).cpu().numpy()
                    embeddings = embeddings.reshape(1, -1)  # Make it (1, D)
                else:
                    embeddings = embedder.image_to_embedding(batch_images).cpu().numpy()  # (B, D)

            # Store embeddings in dictionary
            for img_path, embedding in zip(valid_paths, embeddings):
                embeddings_dict[img_path] = embedding
                processed_count += 1

            progress_bar.update(len(valid_paths))

            # Periodic save every 10 batches
            if (batch_start // args.batch_size) % 10 == 0:
                save_embeddings(embeddings_dict, output_file)

        except Exception as e:
            logger.error(f"Failed to process batch starting at {batch_start}: {e}")
            failed_count += len(batch_paths)
            progress_bar.update(len(batch_paths))
            continue

    progress_bar.close()

    # Final save
    save_embeddings(embeddings_dict, output_file)

    # Summary
    total_processed = len(embeddings_dict)
    logger.info(f"\nProcessing complete!")
    logger.info(f"Total embeddings: {total_processed}")
    logger.info(f"Newly processed: {processed_count}")
    logger.info(f"Failed: {failed_count}")
    logger.info(f"Embedding dimension: {args.embedding_dim}")
    logger.info(f"Output file: {output_file}")

    # Verify a sample embedding
    if embeddings_dict:
        sample_path = next(iter(embeddings_dict))
        sample_emb = embeddings_dict[sample_path]
        logger.info(f"Sample embedding shape: {sample_emb.shape}")
        logger.info(f"Sample embedding norm: {np.linalg.norm(sample_emb):.4f}")

    # Final memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"Final GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")


if __name__ == "__main__":
    main()
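
The output is just a path -> vector dictionary, so downstream consumers can read it back in a few lines. A minimal sketch, assuming the default embeddings/image_embeddings.pkl produced above:

import pickle
import numpy as np
from pathlib import Path

emb_path = Path("embeddings/image_embeddings.pkl")   # default output of the script above
with open(emb_path, "rb") as f:
    embeddings = pickle.load(f)                      # Dict[str, np.ndarray]

print(f"{len(embeddings)} images embedded")
sample_path, vec = next(iter(embeddings.items()))
print(sample_path, vec.shape, float(np.linalg.norm(vec)))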

text_embedder.py (new file, 306 lines)

@@ -0,0 +1,306 @@
import pickle
import hashlib
import logging
from pathlib import Path
from typing import Dict, List, Union
import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPTokenizer, CLIPTextModel
from cachetools import LRUCache
from tqdm import tqdm
from atomicwrites import atomic_write
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use CLIP which is specifically designed for text-image alignment
DEFAULT_MODEL = "openai/clip-vit-base-patch32"
OUTPUT_DIM = 4096 # Match image embeddings dimension
CHUNK_BATCH_SIZE = 16
ARTICLE_CACHE_SIZE = 1024
# Persistence paths
EMBEDDINGS_DIR = Path(__file__).parent / "embeddings"
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"
class CLIPTextEmbedder:
"""
Text embedder using CLIP's text encoder, optimized for text-to-image
retrieval with 4096-dimensional output to match image embeddings.
"""
def __init__(
self,
model_name: str = DEFAULT_MODEL,
output_dim: int = OUTPUT_DIM,
chunk_batch_size: int = CHUNK_BATCH_SIZE,
):
logger.info(f"Initializing CLIPTextEmbedder with {model_name}")
# Load CLIP text encoder
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
self.text_encoder = CLIPTextModel.from_pretrained(model_name).to(DEVICE).eval()
# Freeze the text encoder
for p in self.text_encoder.parameters():
p.requires_grad = False
# Get CLIP's text embedding dimension (usually 512)
clip_dim = self.text_encoder.config.hidden_size
# Enhanced projector to match 4096-D image embeddings
self.projector = nn.Sequential(
nn.Linear(clip_dim, output_dim // 2),
nn.LayerNorm(output_dim // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(output_dim // 2, output_dim),
nn.LayerNorm(output_dim),
nn.ReLU(),
nn.Linear(output_dim, output_dim),
).to(DEVICE)
# Initialize projector with Xavier initialization
for layer in self.projector:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)
self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)
self.chunk_batch_size = chunk_batch_size
# Ensure embeddings dir
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
logger.info(f"CLIPTextEmbedder ready: {clip_dim}D -> {output_dim}D on {DEVICE}")
def text_to_embedding(self, text: Union[str, List[str]]) -> torch.Tensor:
"""
Embed text using CLIP's text encoder + projector
Args:
text: Single string or list of strings
Returns:
torch.Tensor: Normalized embeddings of shape (D,) for single text
or (B, D) for batch of texts
"""
# Handle single text vs batch
if isinstance(text, str):
texts = [text]
single_text = True
else:
texts = text
single_text = False
# Check cache for single text
if single_text and texts[0] in self.article_cache:
return self.article_cache[texts[0]]
# Tokenize and encode with CLIP
inputs = self.tokenizer(
texts,
return_tensors="pt",
truncation=True,
padding=True,
max_length=77 # CLIP's max length
).to(DEVICE)
with torch.no_grad():
# Get CLIP text features
text_outputs = self.text_encoder(**inputs)
# Use pooled output (CLS token equivalent)
clip_features = text_outputs.pooler_output # [batch_size, clip_dim]
# Normalize CLIP features
clip_features = nn.functional.normalize(clip_features, p=2, dim=1)
# Project to target dimension
projected = self.projector(clip_features)
# Normalize final output
projected = nn.functional.normalize(projected, p=2, dim=1)
# Cache single text result
if single_text:
result = projected.squeeze(0).cpu()
self.article_cache[texts[0]] = result
return result
else:
return projected.cpu() # Return batch
def embed_texts_from_file(
self,
text_file: Union[str, Path],
save_path: Union[str, Path] = TEXT_EMB_PATH
) -> Dict[str, np.ndarray]:
"""
Read newline-separated texts and embed them
"""
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
# Load existing embeddings
try:
with open(save_path, "rb") as fp:
mapping: Dict[str, np.ndarray] = pickle.load(fp)
logger.info(f"Loaded {len(mapping)} existing text embeddings")
except FileNotFoundError:
mapping = {}
except Exception as e:
logger.warning(f"Failed loading {save_path}: {e}")
mapping = {}
# Read text file
try:
with open(text_file, "r", encoding="utf-8") as fp:
lines = [ln.strip() for ln in fp if ln.strip()]
except Exception as e:
raise RuntimeError(f"Error reading {text_file}: {e}")
logger.info(f"Processing {len(lines)} texts from {text_file}")
# Process in batches
new_texts = [ln for ln in lines if ln not in mapping]
logger.info(f"Embedding {len(new_texts)} new texts")
if not new_texts:
logger.info("All texts already embedded!")
return mapping
# Process in batches
for i in tqdm(range(0, len(new_texts), self.chunk_batch_size), desc="Embedding texts"):
batch_texts = new_texts[i:i + self.chunk_batch_size]
try:
# Get batch embeddings
batch_embeddings = self.text_to_embedding(batch_texts)
# Store in mapping
for text, embedding in zip(batch_texts, batch_embeddings):
mapping[text] = embedding.numpy()
except Exception as e:
logger.error(f"Embedding failed for batch starting at {i}: {e}")
# Process individually as fallback
for text in batch_texts:
try:
vec = self.text_to_embedding(text)
mapping[text] = vec.numpy()
except Exception as e2:
logger.error(f"Individual embedding failed for text: {e2}")
# Save updated mappings
with atomic_write(save_path, mode='wb', overwrite=True) as f:
pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Saved {len(mapping)} text embeddings to {save_path}")
return mapping
def embed_single_text(self, text: str) -> np.ndarray:
"""Embed a single text and return as numpy array"""
embedding = self.text_to_embedding(text)
return embedding.numpy()
def clear_cache(self):
"""Clear the text cache"""
self.article_cache.clear()
logger.info("Text cache cleared")
def get_memory_usage(self):
"""Get current GPU memory usage if available."""
if torch.cuda.is_available():
return {
'allocated': torch.cuda.memory_allocated() / 1024 ** 3, # GB
'reserved': torch.cuda.memory_reserved() / 1024 ** 3, # GB
}
return {'allocated': 0, 'reserved': 0}
# Alias for compatibility
TextEmbedder = CLIPTextEmbedder
def main():
import argparse
parser = argparse.ArgumentParser("CLIP Text Embedder for Image Search")
parser.add_argument("--text", type=str, help="Single text to embed")
parser.add_argument("--file", type=str, help="Path to newline-delimited texts")
parser.add_argument(
"--out", type=str, default=str(TEXT_EMB_PATH),
help="Pickle path for saving embeddings"
)
parser.add_argument(
"--batch-size", type=int, default=CHUNK_BATCH_SIZE,
help="Batch size for processing"
)
parser.add_argument(
"--model", type=str, default=DEFAULT_MODEL,
help="CLIP model name to use"
)
parser.add_argument(
"--dim", type=int, default=OUTPUT_DIM,
help="Output embedding dimension"
)
args = parser.parse_args()
embedder = CLIPTextEmbedder(
model_name=args.model,
output_dim=args.dim,
chunk_batch_size=args.batch_size,
)
# Show memory usage
mem = embedder.get_memory_usage()
logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
# Single-text mode
if args.text:
vec = embedder.embed_single_text(args.text)
print(f"Text: {args.text}")
print(f"Shape: {vec.shape}")
print(f"Norm: {np.linalg.norm(vec):.4f}")
print(f"Sample values: {vec[:10]}")
# Save to file
save_path = Path(args.out)
save_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(save_path, "rb") as fp:
mapping = pickle.load(fp)
except (FileNotFoundError, pickle.PickleError):
mapping = {}
mapping[args.text] = vec
with atomic_write(save_path, mode='wb', overwrite=True) as f:
pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
logger.info(f"Saved text embedding to {save_path}")
# Batch-file mode
if args.file:
embedder.embed_texts_from_file(args.file, save_path=args.out)
logger.info(f"Completed batch embedding from {args.file}")
if not args.text and not args.file:
# Demo mode
demo_texts = [
"a cat sleeping on a bed",
"a dog running in the park",
"beautiful sunset over mountains",
"red sports car on highway"
]
print("Demo: Embedding sample texts...")
for text in demo_texts:
vec = embedder.embed_single_text(text)
print(f"'{text}' -> shape {vec.shape}, norm {np.linalg.norm(vec):.4f}")
if __name__ == "__main__":
main()
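
A minimal usage sketch of the class above (note that in this commit the projector is only Xavier-initialized, not trained or loaded from a checkpoint):

from text_embedder import CLIPTextEmbedder
import numpy as np

embedder = CLIPTextEmbedder()   # openai/clip-vit-base-patch32, projected to 4096-D
vec = embedder.embed_single_text("a cat sleeping on a bed")
print(vec.shape)                # (4096,)
print(np.linalg.norm(vec))      # ~1.0, outputs are L2-normalized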

text_search_starter.py (new file, 356 lines)

@@ -0,0 +1,356 @@
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
from text_embedder import CLIPTextEmbedder
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default paths
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
DEFAULT_TEXT_MODEL = "openai/clip-vit-base-patch32"
DEFAULT_TOP_K = 10
class TextToImageSearcher:
"""
Text-to-image search using CLIP embeddings
"""
def __init__(
self,
image_embeddings_path: Path,
text_model_name: str = DEFAULT_TEXT_MODEL,
embedding_dim: int = 4096
):
"""
Initialize the Text-to-Image searcher
Args:
image_embeddings_path: Path to the pickle file containing image embeddings
text_model_name: Name of the CLIP text model to use
embedding_dim: Dimension of the embeddings (should match image embeddings)
"""
self.image_embeddings_path = Path(image_embeddings_path)
self.embedding_dim = embedding_dim
# Load image embeddings
logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
self.image_data = self._load_image_embeddings()
# Initialize text embedder
logger.info(f"Initializing text embedder: {text_model_name}")
self.text_embedder = CLIPTextEmbedder(
model_name=text_model_name,
output_dim=embedding_dim
)
logger.info(f"TextToImageSearcher ready with {len(self.image_data)} images")
def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
"""Load image embeddings from pickle file"""
try:
with open(self.image_embeddings_path, "rb") as f:
data = pickle.load(f)
if not isinstance(data, dict):
raise ValueError("Image embeddings file should contain a dictionary")
# Verify embedding dimensions
sample_key = next(iter(data))
sample_emb = data[sample_key]
if sample_emb.shape[-1] != self.embedding_dim:
raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D")
logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
return data
except FileNotFoundError:
raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}")
except Exception as e:
raise RuntimeError(f"Failed to load image embeddings: {e}")
def search(
self,
query_text: str,
top_k: int = DEFAULT_TOP_K,
return_scores: bool = False
) -> List[Dict]:
"""
Search for images matching the query text
Args:
query_text: Text description to search for
top_k: Number of top results to return
return_scores: Whether to include similarity scores
Returns:
List of dictionaries with 'path' and optionally 'score' keys
"""
logger.info(f"Searching for: '{query_text}'")
# Embed query text
query_embedding = self.text_embedder.embed_single_text(query_text)
# Prepare image data for comparison
image_paths = list(self.image_data.keys())
image_embeddings = np.stack(list(self.image_data.values())) # (N, D)
# Compute cosine similarities (both embeddings are normalized)
similarities = np.dot(image_embeddings, query_embedding)
# Get top-k indices
top_k_indices = np.argsort(-similarities)[:top_k]
# Prepare results
results = []
for idx in top_k_indices:
result = {'path': image_paths[idx]}
if return_scores:
result['score'] = float(similarities[idx])
results.append(result)
return results
def batch_search(
self,
query_texts: List[str],
top_k: int = DEFAULT_TOP_K
) -> Dict[str, List[Dict]]:
"""
Search for multiple queries at once
Args:
query_texts: List of text descriptions
top_k: Number of results per query
Returns:
Dictionary mapping query text to list of results
"""
logger.info(f"Batch searching {len(query_texts)} queries")
results = {}
for query in query_texts:
results[query] = self.search(query, top_k=top_k, return_scores=True)
return results
def get_search_stats(self, query_text: str) -> Dict:
"""
Get statistics about search results
Args:
query_text: Query to analyze
Returns:
Dictionary with similarity statistics
"""
query_embedding = self.text_embedder.embed_single_text(query_text)
image_embeddings = np.stack(list(self.image_data.values()))
similarities = np.dot(image_embeddings, query_embedding)
return {
'mean_similarity': float(np.mean(similarities)),
'max_similarity': float(np.max(similarities)),
'min_similarity': float(np.min(similarities)),
'std_similarity': float(np.std(similarities)),
'median_similarity': float(np.median(similarities))
}
def get_memory_usage(self):
"""Get memory usage information"""
return self.text_embedder.get_memory_usage()
def clear_cache(self):
"""Clear text embedding cache"""
self.text_embedder.clear_cache()
def print_results(results: List[Dict], query: str, show_scores: bool = True):
"""Pretty print search results"""
print(f"\nTop {len(results)} results for: '{query}'")
print("-" * 80)
for i, result in enumerate(results, 1):
path = result['path']
if show_scores and 'score' in result:
score = result['score']
print(f"{i:2d}. {path:<60} score={score:>7.4f}")
else:
print(f"{i:2d}. {path}")
print("-" * 80)
def main():
parser = argparse.ArgumentParser(description="Text-to-Image Search using CLIP embeddings")
# Input/Output paths
parser.add_argument(
"--image-embeddings",
type=str,
default=str(DEFAULT_IMG_EMB_PATH),
help="Path to image embeddings pickle file"
)
# Query options
parser.add_argument(
"--query",
type=str,
help="Single text query to search for"
)
parser.add_argument(
"--queries-file",
type=str,
help="File with one query per line"
)
parser.add_argument(
"--interactive",
action="store_true",
help="Interactive mode for multiple queries"
)
# Search parameters
parser.add_argument(
"--top-k",
type=int,
default=DEFAULT_TOP_K,
help="Number of top results to return"
)
parser.add_argument(
"--text-model",
type=str,
default=DEFAULT_TEXT_MODEL,
help="CLIP text model to use"
)
parser.add_argument(
"--embedding-dim",
type=int,
default=4096,
help="Embedding dimension"
)
# Output options
parser.add_argument(
"--output",
type=str,
help="Save results to JSON file"
)
parser.add_argument(
"--stats",
action="store_true",
help="Show similarity statistics"
)
args = parser.parse_args()
# Initialize searcher
try:
searcher = TextToImageSearcher(
image_embeddings_path=args.image_embeddings,
text_model_name=args.text_model,
embedding_dim=args.embedding_dim
)
except Exception as e:
logger.error(f"Failed to initialize searcher: {e}")
return
# Show memory usage
mem = searcher.get_memory_usage()
logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
# Handle different query modes
if args.query:
# Single query
results = searcher.search(args.query, top_k=args.top_k, return_scores=True)
print_results(results, args.query)
if args.stats:
stats = searcher.get_search_stats(args.query)
print(f"\nSimilarity Statistics:")
for key, value in stats.items():
print(f" {key}: {value:.4f}")
elif args.queries_file:
# Batch queries from file
try:
with open(args.queries_file, 'r', encoding='utf-8') as f:
queries = [line.strip() for line in f if line.strip()]
batch_results = searcher.batch_search(queries, top_k=args.top_k)
for query, results in batch_results.items():
print_results(results, query)
except Exception as e:
logger.error(f"Failed to process queries file: {e}")
return
elif args.interactive:
# Interactive mode
print("\nInteractive Text-to-Image Search")
print("Enter text queries (or 'quit' to exit):")
print("-" * 40)
while True:
try:
query = input("\nQuery: ").strip()
if query.lower() in ['quit', 'exit', 'q']:
break
if not query:
continue
results = searcher.search(query, top_k=args.top_k, return_scores=True)
print_results(results, query)
if args.stats:
stats = searcher.get_search_stats(query)
print(f"\nStats: mean={stats['mean_similarity']:.3f}, "
f"max={stats['max_similarity']:.3f}, "
f"std={stats['std_similarity']:.3f}")
except KeyboardInterrupt:
print("\nExiting...")
break
except Exception as e:
logger.error(f"Error processing query: {e}")
else:
# Demo mode with sample queries
demo_queries = [
"a cat sleeping on a bed",
]
print("Demo mode: Running sample queries...")
batch_results = searcher.batch_search(demo_queries, top_k=5)
for query, results in batch_results.items():
print_results(results, query)
# Save results if requested
if args.output and 'batch_results' in locals():
import json
# Convert results to JSON-serializable format
json_results = {}
for query, results in batch_results.items():
json_results[query] = results # Results are already dict format
try:
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(json_results, f, indent=2, ensure_ascii=False)
logger.info(f"Results saved to {args.output}")
except Exception as e:
logger.error(f"Failed to save results: {e}")
if __name__ == "__main__":
main()
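
Taken together, the intended flow is to batch-embed images into embeddings/image_embeddings.pkl and then query them by text. A minimal sketch of that last step, assuming the default paths above:

from text_search_starter import TextToImageSearcher

searcher = TextToImageSearcher("embeddings/image_embeddings.pkl")
for hit in searcher.search("a dog running in the park", top_k=5, return_scores=True):
    print(hit["path"], hit["score"])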