Compare commits
No commits in common. "2103e28aea7e44785790a6a168a42462c24eb79e" and "70b9a14b2ba099c24ac756f8879f87e58da5eb43" have entirely different histories.
2103e28aea ... 70b9a14b2b

embedder.py (182 lines)
@@ -1,31 +1,17 @@
import torch
import torch.nn as nn
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use CUDA if available
# Use CUDA
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageEmbedder:
    """
    Image embedder that outputs 4096-dimensional embeddings using CLIP vision model.
    Maintains compatibility with existing LLaVA weights while reducing dimension.
    """

    def __init__(self,
                 vision_model_name: str = "openai/clip-vit-large-patch14-336",
                 proj_out_dim: int = 4096,  # Changed from 5120 to 4096
                 proj_out_dim: int = 5120,
                 llava_ckpt_path: str = None):

        logger.info(f"Initializing ImageEmbedder with {vision_model_name}")

        # Load CLIP vision encoder + processor (keep exactly the same)
        # Load CLIP vision encoder + processor
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
@@ -33,157 +19,37 @@ class ImageEmbedder:
            .to(DEVICE)
            .eval()
        )

        # Freeze vision model parameters
        # Freeze version
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # Get vision model's hidden dimension
        # Build the projection layer (1024 → proj_out_dim) and move it to DEVICE
        vision_hidden_dim = self.vision_model.config.hidden_size  # should be 1024

        # Keep the simple projection architecture that works!
        # Just change output dimension from 5120 to 4096
        self.projection = nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection = torch.nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
        self.projection.eval()

        # Load LLaVA's projection weights if provided
        # Load LLaVA’s projection weights from the full bin checkpoint
        if llava_ckpt_path is not None:
            self._load_llava_weights(llava_ckpt_path, vision_hidden_dim, proj_out_dim)

        logger.info(f"ImageEmbedder ready: {vision_hidden_dim}D -> {proj_out_dim}D on {DEVICE}")

    def _load_llava_weights(self, ckpt_path, vision_dim, output_dim):
        """Load LLaVA projection weights, adapting for different output dimension"""
        try:
            ckpt = torch.load(ckpt_path, map_location=DEVICE)

            # Extract projector weights (same as your original code)
            ckpt = torch.load(llava_ckpt_path, map_location=DEVICE)
            # extract only the projector weights + bias
            keys = ["model.mm_projector.weight", "model.mm_projector.bias"]
            state = {
                k.replace("model.mm_projector.", ""): ckpt[k]
                for k in keys
            }
            # Load into our linear layer
            self.projection.load_state_dict(state)

            if all(k in ckpt for k in keys):
                original_weight = ckpt["model.mm_projector.weight"]  # Shape: (5120, 1024)
                original_bias = ckpt["model.mm_projector.bias"]  # Shape: (5120,)

    def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
        # preprocess & move to device
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (1,3,H,W)

                if output_dim == 5120:
                    # Direct loading for original dimension
                    state = {
                        "weight": original_weight,
                        "bias": original_bias
                    }
                    self.projection.load_state_dict(state)
                    logger.info("Loaded original LLaVA projection weights")

                elif output_dim < 5120:
                    # Truncate weights for smaller dimension (preserves most important features)
                    truncated_weight = original_weight[:output_dim, :]  # (4096, 1024)
                    truncated_bias = original_bias[:output_dim]  # (4096,)

                    state = {
                        "weight": truncated_weight,
                        "bias": truncated_bias
                    }
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded truncated LLaVA weights: {5120}D -> {output_dim}D")

                else:
                    # For larger dimensions, pad with Xavier initialization
                    new_weight = torch.zeros(output_dim, vision_dim, device=DEVICE)
                    new_bias = torch.zeros(output_dim, device=DEVICE)

                    # Copy original weights
                    new_weight[:5120, :] = original_weight
                    new_bias[:5120] = original_bias

                    # Initialize additional weights
                    nn.init.xavier_uniform_(new_weight[5120:, :])
                    nn.init.zeros_(new_bias[5120:])

                    state = {
                        "weight": new_weight,
                        "bias": new_bias
                    }
                    self.projection.load_state_dict(state)
                    logger.info(f"Loaded extended LLaVA weights: {5120}D -> {output_dim}D")
            else:
                logger.warning("LLaVA projector weights not found, using random initialization")

        except Exception as e:
            logger.warning(f"Failed to load LLaVA weights: {e}, using random initialization")

    def image_to_embedding(self, image_input) -> torch.Tensor:
        """
        Convert image(s) to embeddings - keeping your exact original logic!

        Args:
            image_input: PIL.Image or list of PIL.Images

        Returns:
            torch.Tensor: Embeddings of shape (proj_out_dim,) for single image
                          or (B, proj_out_dim) for batch of images
        """
        # Handle single image vs batch (same as your original)
        if isinstance(image_input, Image.Image):
            images = [image_input]
            single_image = True
        else:
            images = image_input
            single_image = False

        # Preprocess images (same as your original)
        inputs = self.processor(images=images, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(DEVICE)  # (B, 3, H, W)

        # Forward pass (exactly like your original)
        # Forward pass
        with torch.no_grad():
            out = self.vision_model(pixel_values=pixel_values)
            feat = out.pooler_output  # (B, 1024)
            emb = self.projection(feat)  # (B, proj_out_dim)
            feat = out.pooler_output  # (1,1024)
            emb = self.projection(feat)  # (1,proj_out_dim)

        # Return appropriate shape (same as your original)
        if single_image:
            return emb.squeeze(0)  # (proj_out_dim,)
        else:
            return emb  # (B, proj_out_dim)

    def get_memory_usage(self):
        """Get current GPU memory usage if available"""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,  # GB
            }
        return {'allocated': 0, 'reserved': 0}


# Test function
def main():
    """Simple test of the embedder"""
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(description="Test ImageEmbedder")
    parser.add_argument("--image", type=str, help="Path to test image")
    parser.add_argument("--model", type=str, default="openai/clip-vit-large-patch14-336")
    parser.add_argument("--dim", type=int, default=4096)
    parser.add_argument("--llava-ckpt", type=str, help="Path to LLaVA checkpoint")
    args = parser.parse_args()

    embedder = ImageEmbedder(
        vision_model_name=args.model,
        proj_out_dim=args.dim,
        llava_ckpt_path=args.llava_ckpt
    )

    if args.image:
        img = Image.open(args.image).convert("RGB")
        emb = embedder.image_to_embedding(img)
        print(f"Embedding shape: {emb.shape}")
        print(f"Sample values: {emb[:10]}")

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")


if __name__ == "__main__":
    main()
        # Returns a 1D tensor of size [proj_out_dim] on DEVICE.
        return emb.squeeze(0)  # → (proj_out_dim,)
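A minimal usage sketch for the ImageEmbedder shown above (not part of the diff): it assumes embedder.py is importable and uses a hypothetical local image path.

# Sketch only: "example.jpg" is an illustrative placeholder, not a file from this repo.
from PIL import Image
from embedder import ImageEmbedder

embedder = ImageEmbedder()                      # defaults as defined in the class above
img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
emb = embedder.image_to_embedding(img)          # 1-D tensor of length proj_out_dim
print(emb.shape)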
@@ -1,444 +1,36 @@
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pickle
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

from embedder import ImageEmbedder

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
DEFAULT_TOP_K = 10


class ImageSimilaritySearcher:
    """
    Image-to-image similarity search using unified embeddings dictionary
    """

    def __init__(
        self,
        image_embeddings_path: Path,
        vision_model_name: str = DEFAULT_VISION_MODEL,
        embedding_dim: int = 4096,
        llava_ckpt_path: Optional[str] = None
    ):
        self.image_embeddings_path = Path(image_embeddings_path)
        self.embedding_dim = embedding_dim

        # Load image embeddings
        logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
        self.image_data = self._load_image_embeddings()

        # Initialize image embedder for query images
        logger.info(f"Initializing image embedder: {vision_model_name}")
        self.image_embedder = ImageEmbedder(
            vision_model_name=vision_model_name,
            proj_out_dim=embedding_dim,
            llava_ckpt_path=llava_ckpt_path
        )

        logger.info(f"ImageSimilaritySearcher ready with {len(self.image_data)} images")

    def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
        """Load image embeddings from unified pickle file"""
        try:
            with open(self.image_embeddings_path, "rb") as f:
                data = pickle.load(f)

            if not isinstance(data, dict):
                raise ValueError("Image embeddings file should contain a dictionary")

            # Verify embedding dimensions
            if data:
                sample_key = next(iter(data))
                sample_emb = data[sample_key]
                if sample_emb.shape[-1] != self.embedding_dim:
                    raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D")

                logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
            else:
                logger.warning("Empty embeddings dictionary loaded")

            return data

        except FileNotFoundError:
            raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load image embeddings: {e}")

    def search_by_image_path(
        self,
        query_image_path: str,
        top_k: int = DEFAULT_TOP_K,
        exclude_self: bool = True,
        return_scores: bool = False
    ) -> List[Dict]:
        """
        Search for similar images given a query image path

        Args:
            query_image_path: Path to query image
            top_k: Number of top results to return
            exclude_self: Whether to exclude the query image from results
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys
        """
        logger.info(f"Searching for images similar to: {query_image_path}")

        # Load and embed query image
        try:
            query_img = Image.open(query_image_path).convert("RGB")
            query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to process query image {query_image_path}: {e}")

        return self._search_by_embedding(
            query_embedding,
            query_image_path if exclude_self else None,
            top_k,
            return_scores
        )

    def search_by_image(
        self,
        query_image: Image.Image,
        top_k: int = DEFAULT_TOP_K,
        return_scores: bool = False
    ) -> List[Dict]:
        """
        Search for similar images given a PIL Image

        Args:
            query_image: PIL Image object
            top_k: Number of top results to return
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys
        """
        logger.info("Searching for similar images using provided image")

        # Embed query image
        query_embedding = self.image_embedder.image_to_embedding(query_image).cpu().numpy()

        return self._search_by_embedding(query_embedding, None, top_k, return_scores)

    def _search_by_embedding(
        self,
        query_embedding: np.ndarray,
        exclude_path: Optional[str],
        top_k: int,
        return_scores: bool
    ) -> List[Dict]:
        """Internal method to search by embedding vector"""

        # Prepare database embeddings
        image_paths = list(self.image_data.keys())
        image_embeddings = np.stack(list(self.image_data.values()))  # (N, D)

        # Compute cosine similarities
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            image_embeddings
        ).flatten()

        # Handle exclusion
        valid_indices = list(range(len(image_paths)))
        if exclude_path:
            try:
                exclude_idx = image_paths.index(exclude_path)
                valid_indices.remove(exclude_idx)
                similarities = np.delete(similarities, exclude_idx)
                image_paths = [p for i, p in enumerate(image_paths) if i != exclude_idx]
            except ValueError:
                pass  # Query path not in database

        # Get top-k indices
        top_k_indices = np.argsort(-similarities)[:top_k]

        # Prepare results
        results = []
        for idx in top_k_indices:
            result = {'path': image_paths[idx]}
            if return_scores:
                result['score'] = float(similarities[idx])
            results.append(result)

        return results

    def batch_search(
        self,
        query_image_paths: List[str],
        top_k: int = DEFAULT_TOP_K,
        exclude_self: bool = True
    ) -> Dict[str, List[Dict]]:
        """
        Search for multiple query images at once

        Args:
            query_image_paths: List of paths to query images
            top_k: Number of results per query
            exclude_self: Whether to exclude query images from results

        Returns:
            Dictionary mapping query path to list of results
        """
        logger.info(f"Batch searching {len(query_image_paths)} images")

        results = {}
        for query_path in query_image_paths:
            try:
                results[query_path] = self.search_by_image_path(
                    query_path,
                    top_k=top_k,
                    exclude_self=exclude_self,
                    return_scores=True
                )
            except Exception as e:
                logger.error(f"Failed to search for {query_path}: {e}")
                results[query_path] = []

        return results

    def get_similarity_stats(self, query_image_path: str) -> Dict:
        """
        Get statistics about similarity scores for a query image

        Args:
            query_image_path: Path to query image

        Returns:
            Dictionary with similarity statistics
        """
        try:
            query_img = Image.open(query_image_path).convert("RGB")
            query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy()
        except Exception as e:
            raise RuntimeError(f"Failed to process query image: {e}")

        image_embeddings = np.stack(list(self.image_data.values()))
        similarities = cosine_similarity(
            query_embedding.reshape(1, -1),
            image_embeddings
        ).flatten()

        return {
            'mean_similarity': float(np.mean(similarities)),
            'max_similarity': float(np.max(similarities)),
            'min_similarity': float(np.min(similarities)),
            'std_similarity': float(np.std(similarities)),
            'median_similarity': float(np.median(similarities))
        }

    def get_memory_usage(self):
        """Get memory usage information"""
        return self.image_embedder.get_memory_usage()


def print_results(results: List[Dict], query: str, show_scores: bool = True):
    """Pretty print search results"""
    print(f"\nTop {len(results)} similar images to: {Path(query).name}")
    print(f"Query: {query}")
    print("-" * 100)

    for i, result in enumerate(results, 1):
        path = result['path']
        if show_scores and 'score' in result:
            score = result['score']
            print(f"{i:2d}. {path:<80} score={score:>7.4f}")
        else:
            print(f"{i:2d}. {path}")

    print("-" * 100)


def main():
    parser = argparse.ArgumentParser(description="Image-to-Image Similarity Search")

    # Input/Output paths
    parser.add_argument(
        "--image-embeddings",
        type=str,
        default=str(DEFAULT_IMG_EMB_PATH),
        help="Path to image embeddings pickle file"
    )

    # Query options
    parser.add_argument(
        "--query-image",
        type=str,
        help="Path to query image"
    )
    parser.add_argument(
        "--query-images-file",
        type=str,
        help="File with one image path per line"
    )
    parser.add_argument(
        "--query-dir",
        type=str,
        help="Directory containing query images"
    )

    # Model parameters
    parser.add_argument(
        "--vision-model",
        type=str,
        default=DEFAULT_VISION_MODEL,
        help="CLIP vision model to use"
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=4096,
        help="Embedding dimension"
    )
    parser.add_argument(
        "--llava-ckpt",
        type=str,
        help="Path to LLaVA checkpoint for projection weights"
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=DEFAULT_TOP_K,
        help="Number of top results to return"
    )
    parser.add_argument(
        "--include-self",
        action="store_true",
        help="Include query image in results"
    )

    # Output options
    parser.add_argument(
        "--output",
        type=str,
        help="Save results to JSON file"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show similarity statistics"
    )

    args = parser.parse_args()

    # Initialize searcher
    try:
        searcher = ImageSimilaritySearcher(
            image_embeddings_path=args.image_embeddings,
            vision_model_name=args.vision_model,
            embedding_dim=args.embedding_dim,
            llava_ckpt_path=args.llava_ckpt
        )
    except Exception as e:
        logger.error(f"Failed to initialize searcher: {e}")
        return

    # Show memory usage
    mem = searcher.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Handle different query modes
    if args.query_image:
        # Single query image
        results = searcher.search_by_image_path(
            args.query_image,
            top_k=args.top_k,
            exclude_self=not args.include_self,
            return_scores=True
        )
        print_results(results, args.query_image)

        if args.stats:
            stats = searcher.get_similarity_stats(args.query_image)
            print(f"\nSimilarity Statistics:")
            for key, value in stats.items():
                print(f"  {key}: {value:.4f}")

    elif args.query_images_file:
        # Batch queries from file
        try:
            with open(args.query_images_file, 'r', encoding='utf-8') as f:
                query_paths = [line.strip() for line in f if line.strip()]

            batch_results = searcher.batch_search(
                query_paths,
                top_k=args.top_k,
                exclude_self=not args.include_self
            )

            for query_path, results in batch_results.items():
                if results:  # Only show if we got results
                    print_results(results, query_path)

        except Exception as e:
            logger.error(f"Failed to process queries file: {e}")
            return

    elif args.query_dir:
        # Query all images in directory
        query_dir = Path(args.query_dir)
        if not query_dir.exists():
            logger.error(f"Query directory does not exist: {query_dir}")
            return

        # Find images in directory
        image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff', '.tif'}
        query_paths = []
        for ext in image_exts:
            query_paths.extend(query_dir.glob(f"*{ext}"))
            query_paths.extend(query_dir.glob(f"*{ext.upper()}"))

        query_paths = [str(p) for p in sorted(query_paths)]

        if not query_paths:
            logger.error(f"No images found in {query_dir}")
            return

        logger.info(f"Found {len(query_paths)} images in {query_dir}")

        batch_results = searcher.batch_search(
            query_paths,
            top_k=args.top_k,
            exclude_self=not args.include_self
        )

        for query_path, results in batch_results.items():
            if results:  # Only show if we got results
                print_results(results, query_path)

    else:
        print("Please specify a query image, query images file, or query directory")
        print("Use --help for more information")
        return

    # Save results if requested
    if args.output and 'batch_results' in locals():
        import json

        try:
            with open(args.output, 'w', encoding='utf-8') as f:
                json.dump(batch_results, f, indent=2, ensure_ascii=False)
            logger.info(f"Results saved to {args.output}")
        except Exception as e:
            logger.error(f"Failed to save results: {e}")


if __name__ == "__main__":
    main()
# ——— Load stored embeddings & mapping ———
vecs = np.load("imgVecs/image_vectors.npy")  # (N, D)
with open("imgVecs/index_to_file.pkl", "rb") as f:
    idx2file = pickle.load(f)  # dict: idx → filepath

# ——— Specify query image ———
query_path = "datasets/coco/val2017/1140002154.jpg"

# ——— Embed the query image ———
embedder = ImageEmbedder(
    vision_model_name="openai/clip-vit-large-patch14-336",
    proj_out_dim=5120,
    llava_ckpt_path="datasets/pytorch_model-00003-of-00003.bin"
)
img = Image.open(query_path).convert("RGB")
q_vec = embedder.image_to_embedding(img).cpu().numpy()  # (D,)

# ——— Compute similarities & retrieve top-k ———
# (you can normalize if you built your DB with inner-product indexing)
# vecs_norm = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
# q_vec_norm = q_vec / np.linalg.norm(q_vec)
# sims = cosine_similarity(q_vec_norm.reshape(1, -1), vecs_norm).flatten()
sims = cosine_similarity(q_vec.reshape(1, -1), vecs).flatten()
top5 = sims.argsort()[-5:][::-1]

# ——— Print out the results ———
print(f"Query image: {query_path}\n")
for rank, idx in enumerate(top5, 1):
    print(f"{rank:>2}. {idx2file[idx]} (score: {sims[idx]:.4f})")
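The file name of the similarity-search script above is not shown in this diff; assuming it is saved as, say, image_search.py, one plausible invocation of its CLI (the query path is a placeholder) would be:

# Sketch only: script name and query path are assumptions, flags come from the argparse setup above.
# python image_search.py --query-image path/to/query.jpg --top-k 5 --stats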
starter.py (240 lines)
@@ -1,231 +1,101 @@
import os
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict

import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from atomicwrites import atomic_write

from embedder import ImageEmbedder, DEVICE

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_OUTPUT_DIR = Path("embeddings")
DEFAULT_EMB_FILE = "image_embeddings.pkl"


def parse_args():
    parser = argparse.ArgumentParser(
        description="Batch-convert images into 4096D embeddings stored in unified dictionary"
    p = argparse.ArgumentParser(
        description="Batch-convert a folder of images into embedding vectors"
    )
    parser.add_argument(
    p.add_argument(
        "--image-dir", required=True,
        help="Path to folder containing images (jpg/png/bmp/gif/webp)"
        help="Path to a folder containing images (jpg/png/bmp/gif)"
    )
    parser.add_argument(
        "--output-dir", default=str(DEFAULT_OUTPUT_DIR),
        help="Directory to save embeddings file"
    p.add_argument(
        "--output-dir", default="processed_images",
        help="Where to save image_vectors.npy and index_to_file.pkl"
    )
    parser.add_argument(
        "--output-file", default=DEFAULT_EMB_FILE,
        help="Filename for embeddings dictionary (pkl format)"
    )
    parser.add_argument(
    p.add_argument(
        "--vision-model", default="openai/clip-vit-large-patch14-336",
        help="Hugging Face CLIP vision model name"
        help="Hugging Face name of the CLIP vision encoder"
    )
    parser.add_argument(
        "--embedding-dim", type=int, default=4096,
        help="Output embedding dimension"
    p.add_argument(
        "--proj-dim", type=int, default=5120,
        help="Dimensionality of the projection output"
    )
    parser.add_argument(
    p.add_argument(
        "--llava-ckpt", default=None,
        help="(Optional) LLaVA checkpoint to load projection weights from"
        help="(Optional) full LLaVA checkpoint .bin to load projector weights from"
    )
    parser.add_argument(
        "--batch-size", type=int, default=32,
        help="Batch size for processing images"
    p.add_argument(
        "--batch-size", type=int, default=64,
        help="How many images to encode per GPU/CPU batch"
    )
    parser.add_argument(
        "--resume", action="store_true",
        help="Resume from existing embeddings file (skip already processed images)"
    )
    return parser.parse_args()
    return p.parse_args()


def find_images(folder: str) -> list:
    """Find all supported image files in folder (non-recursive by default)"""
    supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".tif")
    image_paths = []

    folder_path = Path(folder)

    # Use a set to avoid duplicates
    found_files = set()

    # Search for files with supported extensions (case-insensitive)
    for file_path in folder_path.iterdir():
        if file_path.is_file():
            file_ext = file_path.suffix.lower()
            if file_ext in supported_exts:
                found_files.add(str(file_path.resolve()))

    return sorted(list(found_files))


def load_existing_embeddings(filepath: Path) -> Dict[str, np.ndarray]:
    """Load existing embeddings dictionary"""
    if not filepath.exists():
        return {}

    try:
        with open(filepath, "rb") as f:
            embeddings = pickle.load(f)
        logger.info(f"Loaded {len(embeddings)} existing embeddings from {filepath}")
        return embeddings
    except Exception as e:
        logger.warning(f"Failed to load existing embeddings: {e}")
        return {}


def save_embeddings(embeddings: Dict[str, np.ndarray], filepath: Path):
    """Safely save embeddings dictionary"""
    filepath.parent.mkdir(parents=True, exist_ok=True)

    with atomic_write(filepath, mode='wb', overwrite=True) as f:
        pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)

    logger.info(f"Saved {len(embeddings)} embeddings to {filepath}")
def find_images(folder):
    exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif")
    return sorted([
        os.path.join(folder, fname)
        for fname in os.listdir(folder)
        if fname.lower().endswith(exts)
    ])


def main():
    args = parse_args()

    # Setup paths
    output_dir = Path(args.output_dir)
    output_file = output_dir / args.output_file
    os.makedirs(args.output_dir, exist_ok=True)

    # Discover images
    logger.info(f"Scanning for images in {args.image_dir}")
    image_paths = find_images(args.image_dir)

    if not image_paths:
        raise RuntimeError(f"No supported images found in {args.image_dir}")
        raise RuntimeError(f"No images found in {args.image_dir}")
    print(f"Found {len(image_paths)} images in {args.image_dir}")

    logger.info(f"Found {len(image_paths)} images")

    # Load existing embeddings if resuming
    embeddings_dict = {}
    if args.resume:
        embeddings_dict = load_existing_embeddings(output_file)

        # Filter out already processed images
        new_image_paths = [p for p in image_paths if p not in embeddings_dict]
        logger.info(f"Resuming: {len(embeddings_dict)} already processed, {len(new_image_paths)} remaining")
        image_paths = new_image_paths

        if not image_paths:
            logger.info("All images already processed!")
            return

    # Initialize embedder
    logger.info(f"Initializing ImageEmbedder on {DEVICE}")
    # Build embedder
    embedder = ImageEmbedder(
        vision_model_name=args.vision_model,
        proj_out_dim=args.embedding_dim,
        proj_out_dim=args.proj_dim,
        llava_ckpt_path=args.llava_ckpt
    )
    print(f"Using device: {DEVICE}")

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Process images in batches
    processed_count = 0
    failed_count = 0

    progress_bar = tqdm(total=len(image_paths), desc="Processing images")

    # Process in batches
    all_embs = []
    index_to_file = {}
    for batch_start in range(0, len(image_paths), args.batch_size):
        batch_end = min(batch_start + args.batch_size, len(image_paths))
        batch_paths = image_paths[batch_start:batch_end]
        batch_paths = image_paths[batch_start:batch_start + args.batch_size]
        # load images
        imgs = [Image.open(p).convert("RGB") for p in batch_paths]
        # embed
        with torch.no_grad():
            embs = embedder.image_to_embedding(imgs)  # (B, D)
        embs_np = embs.cpu().numpy()
        all_embs.append(embs_np)
        # record mapping
        for i, p in enumerate(batch_paths):
            index_to_file[batch_start + i] = p

        # Load batch images
        batch_images = []
        valid_paths = []
        print(f"  • Processed {batch_start + len(batch_paths)}/{len(image_paths)} images")

        for img_path in batch_paths:
            try:
                img = Image.open(img_path).convert("RGB")
                batch_images.append(img)
                valid_paths.append(img_path)
            except Exception as e:
                logger.warning(f"Failed to load {img_path}: {e}")
                failed_count += 1
                progress_bar.update(1)
                continue
    # Stack and save
    vectors = np.vstack(all_embs)  # shape (N, D)
    vec_file = os.path.join(args.output_dir, "image_vectors.npy")
    map_file = os.path.join(args.output_dir, "index_to_file.pkl")

        if not batch_images:
            continue
    np.save(vec_file, vectors)
    with open(map_file, "wb") as f:
        pickle.dump(index_to_file, f)

        # Generate embeddings for batch
        try:
            with torch.no_grad():
                if len(batch_images) == 1:
                    embeddings = embedder.image_to_embedding(batch_images[0]).cpu().numpy()
                    embeddings = embeddings.reshape(1, -1)  # Make it (1, D)
                else:
                    embeddings = embedder.image_to_embedding(batch_images).cpu().numpy()  # (B, D)

            # Store embeddings in dictionary
            for img_path, embedding in zip(valid_paths, embeddings):
                embeddings_dict[img_path] = embedding
                processed_count += 1

            progress_bar.update(len(valid_paths))

            # Periodic save every 10 batches
            if (batch_start // args.batch_size) % 10 == 0:
                save_embeddings(embeddings_dict, output_file)

        except Exception as e:
            logger.error(f"Failed to process batch starting at {batch_start}: {e}")
            failed_count += len(batch_paths)
            progress_bar.update(len(batch_paths))
            continue

    progress_bar.close()

    # Final save
    save_embeddings(embeddings_dict, output_file)

    # Summary
    total_processed = len(embeddings_dict)
    logger.info(f"\nProcessing complete!")
    logger.info(f"Total embeddings: {total_processed}")
    logger.info(f"Newly processed: {processed_count}")
    logger.info(f"Failed: {failed_count}")
    logger.info(f"Embedding dimension: {args.embedding_dim}")
    logger.info(f"Output file: {output_file}")

    # Verify a sample embedding
    if embeddings_dict:
        sample_path = next(iter(embeddings_dict))
        sample_emb = embeddings_dict[sample_path]
        logger.info(f"Sample embedding shape: {sample_emb.shape}")
        logger.info(f"Sample embedding norm: {np.linalg.norm(sample_emb):.4f}")

    # Final memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"Final GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
    print(f"\nSaved {vectors.shape[0]}×{vectors.shape[1]} vectors to\n  {vec_file}")
    print(f"Saved index→file mapping to\n  {map_file}")


if __name__ == "__main__":
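An example invocation of starter.py, using the flags from the left-hand (4096-D, pickle-dictionary) version of parse_args above; the image directory is a placeholder, and the right-hand version would use --proj-dim instead of --embedding-dim:

# Sketch only: path/to/images is illustrative.
# python starter.py --image-dir path/to/images --output-dir embeddings --embedding-dim 4096 --batch-size 32 --resume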
text_embedder.py (306 lines)
@@ -1,306 +0,0 @@
import pickle
import hashlib
import logging
from pathlib import Path
from typing import Dict, List, Union

import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPTokenizer, CLIPTextModel
from cachetools import LRUCache
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use CLIP which is specifically designed for text-image alignment
DEFAULT_MODEL = "openai/clip-vit-base-patch32"
OUTPUT_DIM = 4096  # Match image embeddings dimension
CHUNK_BATCH_SIZE = 16
ARTICLE_CACHE_SIZE = 1024

# Persistence paths
EMBEDDINGS_DIR = Path(__file__).parent / "embeddings"
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"


class CLIPTextEmbedder:
    """
    Text embedder using CLIP's text encoder, optimized for text-to-image
    retrieval with 4096-dimensional output to match image embeddings.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        output_dim: int = OUTPUT_DIM,
        chunk_batch_size: int = CHUNK_BATCH_SIZE,
    ):
        logger.info(f"Initializing CLIPTextEmbedder with {model_name}")

        # Load CLIP text encoder
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_encoder = CLIPTextModel.from_pretrained(model_name).to(DEVICE).eval()

        # Freeze the text encoder
        for p in self.text_encoder.parameters():
            p.requires_grad = False

        # Get CLIP's text embedding dimension (usually 512)
        clip_dim = self.text_encoder.config.hidden_size

        # Enhanced projector to match 4096-D image embeddings
        self.projector = nn.Sequential(
            nn.Linear(clip_dim, output_dim // 2),
            nn.LayerNorm(output_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim // 2, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ).to(DEVICE)

        # Initialize projector with Xavier initialization
        for layer in self.projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

        self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)
        self.chunk_batch_size = chunk_batch_size

        # Ensure embeddings dir
        EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

        logger.info(f"CLIPTextEmbedder ready: {clip_dim}D -> {output_dim}D on {DEVICE}")

    def text_to_embedding(self, text: Union[str, List[str]]) -> torch.Tensor:
        """
        Embed text using CLIP's text encoder + projector

        Args:
            text: Single string or list of strings

        Returns:
            torch.Tensor: Normalized embeddings of shape (D,) for single text
                          or (B, D) for batch of texts
        """
        # Handle single text vs batch
        if isinstance(text, str):
            texts = [text]
            single_text = True
        else:
            texts = text
            single_text = False

        # Check cache for single text
        if single_text and texts[0] in self.article_cache:
            return self.article_cache[texts[0]]

        # Tokenize and encode with CLIP
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=77  # CLIP's max length
        ).to(DEVICE)

        with torch.no_grad():
            # Get CLIP text features
            text_outputs = self.text_encoder(**inputs)
            # Use pooled output (CLS token equivalent)
            clip_features = text_outputs.pooler_output  # [batch_size, clip_dim]

            # Normalize CLIP features
            clip_features = nn.functional.normalize(clip_features, p=2, dim=1)

            # Project to target dimension
            projected = self.projector(clip_features)
            # Normalize final output
            projected = nn.functional.normalize(projected, p=2, dim=1)

        # Cache single text result
        if single_text:
            result = projected.squeeze(0).cpu()
            self.article_cache[texts[0]] = result
            return result
        else:
            return projected.cpu()  # Return batch

    def embed_texts_from_file(
        self,
        text_file: Union[str, Path],
        save_path: Union[str, Path] = TEXT_EMB_PATH
    ) -> Dict[str, np.ndarray]:
        """
        Read newline-separated texts and embed them
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Load existing embeddings
        try:
            with open(save_path, "rb") as fp:
                mapping: Dict[str, np.ndarray] = pickle.load(fp)
            logger.info(f"Loaded {len(mapping)} existing text embeddings")
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Failed loading {save_path}: {e}")
            mapping = {}

        # Read text file
        try:
            with open(text_file, "r", encoding="utf-8") as fp:
                lines = [ln.strip() for ln in fp if ln.strip()]
        except Exception as e:
            raise RuntimeError(f"Error reading {text_file}: {e}")

        logger.info(f"Processing {len(lines)} texts from {text_file}")

        # Process in batches
        new_texts = [ln for ln in lines if ln not in mapping]
        logger.info(f"Embedding {len(new_texts)} new texts")

        if not new_texts:
            logger.info("All texts already embedded!")
            return mapping

        # Process in batches
        for i in tqdm(range(0, len(new_texts), self.chunk_batch_size), desc="Embedding texts"):
            batch_texts = new_texts[i:i + self.chunk_batch_size]

            try:
                # Get batch embeddings
                batch_embeddings = self.text_to_embedding(batch_texts)

                # Store in mapping
                for text, embedding in zip(batch_texts, batch_embeddings):
                    mapping[text] = embedding.numpy()

            except Exception as e:
                logger.error(f"Embedding failed for batch starting at {i}: {e}")
                # Process individually as fallback
                for text in batch_texts:
                    try:
                        vec = self.text_to_embedding(text)
                        mapping[text] = vec.numpy()
                    except Exception as e2:
                        logger.error(f"Individual embedding failed for text: {e2}")

        # Save updated mappings
        with atomic_write(save_path, mode='wb', overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.info(f"Saved {len(mapping)} text embeddings to {save_path}")
        return mapping

    def embed_single_text(self, text: str) -> np.ndarray:
        """Embed a single text and return as numpy array"""
        embedding = self.text_to_embedding(text)
        return embedding.numpy()

    def clear_cache(self):
        """Clear the text cache"""
        self.article_cache.clear()
        logger.info("Text cache cleared")

    def get_memory_usage(self):
        """Get current GPU memory usage if available."""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,  # GB
            }
        return {'allocated': 0, 'reserved': 0}


# Alias for compatibility
TextEmbedder = CLIPTextEmbedder


def main():
    import argparse

    parser = argparse.ArgumentParser("CLIP Text Embedder for Image Search")
    parser.add_argument("--text", type=str, help="Single text to embed")
    parser.add_argument("--file", type=str, help="Path to newline-delimited texts")
    parser.add_argument(
        "--out", type=str, default=str(TEXT_EMB_PATH),
        help="Pickle path for saving embeddings"
    )
    parser.add_argument(
        "--batch-size", type=int, default=CHUNK_BATCH_SIZE,
        help="Batch size for processing"
    )
    parser.add_argument(
        "--model", type=str, default=DEFAULT_MODEL,
        help="CLIP model name to use"
    )
    parser.add_argument(
        "--dim", type=int, default=OUTPUT_DIM,
        help="Output embedding dimension"
    )
    args = parser.parse_args()

    embedder = CLIPTextEmbedder(
        model_name=args.model,
        output_dim=args.dim,
        chunk_batch_size=args.batch_size,
    )

    # Show memory usage
    mem = embedder.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Single-text mode
    if args.text:
        vec = embedder.embed_single_text(args.text)
        print(f"Text: {args.text}")
        print(f"Shape: {vec.shape}")
        print(f"Norm: {np.linalg.norm(vec):.4f}")
        print(f"Sample values: {vec[:10]}")

        # Save to file
        save_path = Path(args.out)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(save_path, "rb") as fp:
                mapping = pickle.load(fp)
        except (FileNotFoundError, pickle.PickleError):
            mapping = {}

        mapping[args.text] = vec
        with atomic_write(save_path, mode='wb', overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Saved text embedding to {save_path}")

    # Batch-file mode
    if args.file:
        embedder.embed_texts_from_file(args.file, save_path=args.out)
        logger.info(f"Completed batch embedding from {args.file}")

    if not args.text and not args.file:
        # Demo mode
        demo_texts = [
            "a cat sleeping on a bed",
            "a dog running in the park",
            "beautiful sunset over mountains",
            "red sports car on highway"
        ]

        print("Demo: Embedding sample texts...")
        for text in demo_texts:
            vec = embedder.embed_single_text(text)
            print(f"'{text}' -> shape {vec.shape}, norm {np.linalg.norm(vec):.4f}")


if __name__ == "__main__":
    main()
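A minimal sketch of using the CLIPTextEmbedder shown above (assumes text_embedder.py is importable); because the class L2-normalizes its output, the vector norm should come out close to 1:

# Sketch only, based on the class defaults above.
import numpy as np
from text_embedder import CLIPTextEmbedder

embedder = CLIPTextEmbedder()            # clip-vit-base-patch32 -> 4096-D projector by default
vec = embedder.embed_single_text("a cat sleeping on a bed")
print(vec.shape)                          # (4096,)
print(np.linalg.norm(vec))                # ~1.0, since outputs are L2-normalized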
@@ -1,356 +0,0 @@
import argparse
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch

from text_embedder import CLIPTextEmbedder

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
DEFAULT_TEXT_MODEL = "openai/clip-vit-base-patch32"
DEFAULT_TOP_K = 10


class TextToImageSearcher:
    """
    Text-to-image search using CLIP embeddings
    """

    def __init__(
        self,
        image_embeddings_path: Path,
        text_model_name: str = DEFAULT_TEXT_MODEL,
        embedding_dim: int = 4096
    ):
        """
        Initialize the Text-to-Image searcher

        Args:
            image_embeddings_path: Path to the pickle file containing image embeddings
            text_model_name: Name of the CLIP text model to use
            embedding_dim: Dimension of the embeddings (should match image embeddings)
        """
        self.image_embeddings_path = Path(image_embeddings_path)
        self.embedding_dim = embedding_dim

        # Load image embeddings
        logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
        self.image_data = self._load_image_embeddings()

        # Initialize text embedder
        logger.info(f"Initializing text embedder: {text_model_name}")
        self.text_embedder = CLIPTextEmbedder(
            model_name=text_model_name,
            output_dim=embedding_dim
        )

        logger.info(f"TextToImageSearcher ready with {len(self.image_data)} images")

    def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
        """Load image embeddings from pickle file"""
        try:
            with open(self.image_embeddings_path, "rb") as f:
                data = pickle.load(f)

            if not isinstance(data, dict):
                raise ValueError("Image embeddings file should contain a dictionary")

            # Verify embedding dimensions
            sample_key = next(iter(data))
            sample_emb = data[sample_key]
            if sample_emb.shape[-1] != self.embedding_dim:
                raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D")

            logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
            return data

        except FileNotFoundError:
            raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load image embeddings: {e}")

    def search(
        self,
        query_text: str,
        top_k: int = DEFAULT_TOP_K,
        return_scores: bool = False
    ) -> List[Dict]:
        """
        Search for images matching the query text

        Args:
            query_text: Text description to search for
            top_k: Number of top results to return
            return_scores: Whether to include similarity scores

        Returns:
            List of dictionaries with 'path' and optionally 'score' keys
        """
        logger.info(f"Searching for: '{query_text}'")

        # Embed query text
        query_embedding = self.text_embedder.embed_single_text(query_text)

        # Prepare image data for comparison
        image_paths = list(self.image_data.keys())
        image_embeddings = np.stack(list(self.image_data.values()))  # (N, D)

        # Compute cosine similarities (both embeddings are normalized)
        similarities = np.dot(image_embeddings, query_embedding)

        # Get top-k indices
        top_k_indices = np.argsort(-similarities)[:top_k]

        # Prepare results
        results = []
        for idx in top_k_indices:
            result = {'path': image_paths[idx]}
            if return_scores:
                result['score'] = float(similarities[idx])
            results.append(result)

        return results

    def batch_search(
        self,
        query_texts: List[str],
        top_k: int = DEFAULT_TOP_K
    ) -> Dict[str, List[Dict]]:
        """
        Search for multiple queries at once

        Args:
            query_texts: List of text descriptions
            top_k: Number of results per query

        Returns:
            Dictionary mapping query text to list of results
        """
        logger.info(f"Batch searching {len(query_texts)} queries")

        results = {}
        for query in query_texts:
            results[query] = self.search(query, top_k=top_k, return_scores=True)

        return results

    def get_search_stats(self, query_text: str) -> Dict:
        """
        Get statistics about search results

        Args:
            query_text: Query to analyze

        Returns:
            Dictionary with similarity statistics
        """
        query_embedding = self.text_embedder.embed_single_text(query_text)
        image_embeddings = np.stack(list(self.image_data.values()))
        similarities = np.dot(image_embeddings, query_embedding)

        return {
            'mean_similarity': float(np.mean(similarities)),
            'max_similarity': float(np.max(similarities)),
            'min_similarity': float(np.min(similarities)),
            'std_similarity': float(np.std(similarities)),
            'median_similarity': float(np.median(similarities))
        }

    def get_memory_usage(self):
        """Get memory usage information"""
        return self.text_embedder.get_memory_usage()

    def clear_cache(self):
        """Clear text embedding cache"""
        self.text_embedder.clear_cache()


def print_results(results: List[Dict], query: str, show_scores: bool = True):
    """Pretty print search results"""
    print(f"\nTop {len(results)} results for: '{query}'")
    print("-" * 80)

    for i, result in enumerate(results, 1):
        path = result['path']
        if show_scores and 'score' in result:
            score = result['score']
            print(f"{i:2d}. {path:<60} score={score:>7.4f}")
        else:
            print(f"{i:2d}. {path}")

    print("-" * 80)


def main():
    parser = argparse.ArgumentParser(description="Text-to-Image Search using CLIP embeddings")

    # Input/Output paths
    parser.add_argument(
        "--image-embeddings",
        type=str,
        default=str(DEFAULT_IMG_EMB_PATH),
        help="Path to image embeddings pickle file"
    )

    # Query options
    parser.add_argument(
        "--query",
        type=str,
        help="Single text query to search for"
    )
    parser.add_argument(
        "--queries-file",
        type=str,
        help="File with one query per line"
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Interactive mode for multiple queries"
    )

    # Search parameters
    parser.add_argument(
        "--top-k",
        type=int,
        default=DEFAULT_TOP_K,
        help="Number of top results to return"
    )
    parser.add_argument(
        "--text-model",
        type=str,
        default=DEFAULT_TEXT_MODEL,
        help="CLIP text model to use"
    )
    parser.add_argument(
        "--embedding-dim",
        type=int,
        default=4096,
        help="Embedding dimension"
    )

    # Output options
    parser.add_argument(
        "--output",
        type=str,
        help="Save results to JSON file"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show similarity statistics"
    )

    args = parser.parse_args()

    # Initialize searcher
    try:
        searcher = TextToImageSearcher(
            image_embeddings_path=args.image_embeddings,
            text_model_name=args.text_model,
            embedding_dim=args.embedding_dim
        )
    except Exception as e:
        logger.error(f"Failed to initialize searcher: {e}")
        return

    # Show memory usage
    mem = searcher.get_memory_usage()
    logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")

    # Handle different query modes
    if args.query:
        # Single query
        results = searcher.search(args.query, top_k=args.top_k, return_scores=True)
        print_results(results, args.query)

        if args.stats:
            stats = searcher.get_search_stats(args.query)
            print(f"\nSimilarity Statistics:")
            for key, value in stats.items():
                print(f"  {key}: {value:.4f}")

    elif args.queries_file:
        # Batch queries from file
        try:
            with open(args.queries_file, 'r', encoding='utf-8') as f:
                queries = [line.strip() for line in f if line.strip()]

            batch_results = searcher.batch_search(queries, top_k=args.top_k)

            for query, results in batch_results.items():
                print_results(results, query)

        except Exception as e:
            logger.error(f"Failed to process queries file: {e}")
            return

    elif args.interactive:
        # Interactive mode
        print("\nInteractive Text-to-Image Search")
        print("Enter text queries (or 'quit' to exit):")
        print("-" * 40)

        while True:
            try:
                query = input("\nQuery: ").strip()
                if query.lower() in ['quit', 'exit', 'q']:
                    break

                if not query:
                    continue

                results = searcher.search(query, top_k=args.top_k, return_scores=True)
                print_results(results, query)

                if args.stats:
                    stats = searcher.get_search_stats(query)
                    print(f"\nStats: mean={stats['mean_similarity']:.3f}, "
                          f"max={stats['max_similarity']:.3f}, "
                          f"std={stats['std_similarity']:.3f}")

            except KeyboardInterrupt:
                print("\nExiting...")
                break
            except Exception as e:
                logger.error(f"Error processing query: {e}")

    else:
        # Demo mode with sample queries
        demo_queries = [
            "a cat sleeping on a bed",
        ]

        print("Demo mode: Running sample queries...")
        batch_results = searcher.batch_search(demo_queries, top_k=5)

        for query, results in batch_results.items():
            print_results(results, query)

    # Save results if requested
    if args.output and 'batch_results' in locals():
        import json

        # Convert results to JSON-serializable format
        json_results = {}
        for query, results in batch_results.items():
            json_results[query] = results  # Results are already dict format

        try:
            with open(args.output, 'w', encoding='utf-8') as f:
                json.dump(json_results, f, indent=2, ensure_ascii=False)
            logger.info(f"Results saved to {args.output}")
        except Exception as e:
            logger.error(f"Failed to save results: {e}")


if __name__ == "__main__":
    main()
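A note on the scoring used in TextToImageSearcher.search above: because both the text and image embeddings are L2-normalized, a plain dot product gives the same ranking as cosine similarity. A small sketch with synthetic vectors illustrates this (dimensions match the 4096-D setup; the data is random and illustrative only):

# Sketch only: random vectors stand in for real embeddings.
import numpy as np

q = np.random.randn(4096)
q /= np.linalg.norm(q)                                  # unit-norm query
db = np.random.randn(10, 4096)
db /= np.linalg.norm(db, axis=1, keepdims=True)         # unit-norm database rows

scores = db @ q                                          # dot product == cosine similarity here
top5 = np.argsort(-scores)[:5]                           # same top-k selection as in search()
print(top5, scores[top5])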