Compare commits
No commits in common. "4096Attempt" and "main" have entirely different histories.
4096Attemp
...
main
176
embedder.py
176
embedder.py
@ -1,31 +1,17 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import logging
|
|
||||||
|
|
||||||
# Configure logging
|
# Use CUDA
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Use CUDA if available
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
class ImageEmbedder:
|
class ImageEmbedder:
|
||||||
"""
|
|
||||||
Image embedder that outputs 4096-dimensional embeddings using CLIP vision model.
|
|
||||||
Maintains compatibility with existing LLaVA weights while reducing dimension.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
vision_model_name: str = "openai/clip-vit-large-patch14-336",
|
vision_model_name: str = "openai/clip-vit-large-patch14-336",
|
||||||
proj_out_dim: int = 4096, # Changed from 5120 to 4096
|
proj_out_dim: int = 5120,
|
||||||
llava_ckpt_path: str = None):
|
llava_ckpt_path: str = None):
|
||||||
|
# Load CLIP vision encoder + processor
|
||||||
logger.info(f"Initializing ImageEmbedder with {vision_model_name}")
|
|
||||||
|
|
||||||
# Load CLIP vision encoder + processor (keep exactly the same)
|
|
||||||
self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
|
self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
|
||||||
self.vision_model = (
|
self.vision_model = (
|
||||||
CLIPVisionModel
|
CLIPVisionModel
|
||||||
@ -33,157 +19,37 @@ class ImageEmbedder:
|
|||||||
.to(DEVICE)
|
.to(DEVICE)
|
||||||
.eval()
|
.eval()
|
||||||
)
|
)
|
||||||
|
# Freeze version
|
||||||
# Freeze vision model parameters
|
|
||||||
for p in self.vision_model.parameters():
|
for p in self.vision_model.parameters():
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
|
||||||
# Get vision model's hidden dimension
|
# Build the projection layer (1024 → proj_out_dim) and move it to DEVICE
|
||||||
vision_hidden_dim = self.vision_model.config.hidden_size # should be 1024
|
vision_hidden_dim = self.vision_model.config.hidden_size # should be 1024
|
||||||
|
self.projection = torch.nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
|
||||||
# Keep the simple projection architecture that works!
|
|
||||||
# Just change output dimension from 5120 to 4096
|
|
||||||
self.projection = nn.Linear(vision_hidden_dim, proj_out_dim).to(DEVICE)
|
|
||||||
self.projection.eval()
|
self.projection.eval()
|
||||||
|
|
||||||
# Load LLaVA's projection weights if provided
|
# Load LLaVA’s projection weights from the full bin checkpoint
|
||||||
if llava_ckpt_path is not None:
|
if llava_ckpt_path is not None:
|
||||||
self._load_llava_weights(llava_ckpt_path, vision_hidden_dim, proj_out_dim)
|
ckpt = torch.load(llava_ckpt_path, map_location=DEVICE)
|
||||||
|
# extract only the projector weights + bias
|
||||||
logger.info(f"ImageEmbedder ready: {vision_hidden_dim}D -> {proj_out_dim}D on {DEVICE}")
|
|
||||||
|
|
||||||
def _load_llava_weights(self, ckpt_path, vision_dim, output_dim):
|
|
||||||
"""Load LLaVA projection weights, adapting for different output dimension"""
|
|
||||||
try:
|
|
||||||
ckpt = torch.load(ckpt_path, map_location=DEVICE)
|
|
||||||
|
|
||||||
# Extract projector weights (same as your original code)
|
|
||||||
keys = ["model.mm_projector.weight", "model.mm_projector.bias"]
|
keys = ["model.mm_projector.weight", "model.mm_projector.bias"]
|
||||||
|
|
||||||
if all(k in ckpt for k in keys):
|
|
||||||
original_weight = ckpt["model.mm_projector.weight"] # Shape: (5120, 1024)
|
|
||||||
original_bias = ckpt["model.mm_projector.bias"] # Shape: (5120,)
|
|
||||||
|
|
||||||
if output_dim == 5120:
|
|
||||||
# Direct loading for original dimension
|
|
||||||
state = {
|
state = {
|
||||||
"weight": original_weight,
|
k.replace("model.mm_projector.", ""): ckpt[k]
|
||||||
"bias": original_bias
|
for k in keys
|
||||||
}
|
}
|
||||||
|
# Load into our linear layer
|
||||||
self.projection.load_state_dict(state)
|
self.projection.load_state_dict(state)
|
||||||
logger.info("Loaded original LLaVA projection weights")
|
|
||||||
|
|
||||||
elif output_dim < 5120:
|
def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
|
||||||
# Truncate weights for smaller dimension (preserves most important features)
|
# preprocess & move to device
|
||||||
truncated_weight = original_weight[:output_dim, :] # (4096, 1024)
|
inputs = self.processor(images=image, return_tensors="pt")
|
||||||
truncated_bias = original_bias[:output_dim] # (4096,)
|
pixel_values = inputs.pixel_values.to(DEVICE) # (1,3,H,W)
|
||||||
|
|
||||||
state = {
|
# Forward pass
|
||||||
"weight": truncated_weight,
|
|
||||||
"bias": truncated_bias
|
|
||||||
}
|
|
||||||
self.projection.load_state_dict(state)
|
|
||||||
logger.info(f"Loaded truncated LLaVA weights: {5120}D -> {output_dim}D")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# For larger dimensions, pad with Xavier initialization
|
|
||||||
new_weight = torch.zeros(output_dim, vision_dim, device=DEVICE)
|
|
||||||
new_bias = torch.zeros(output_dim, device=DEVICE)
|
|
||||||
|
|
||||||
# Copy original weights
|
|
||||||
new_weight[:5120, :] = original_weight
|
|
||||||
new_bias[:5120] = original_bias
|
|
||||||
|
|
||||||
# Initialize additional weights
|
|
||||||
nn.init.xavier_uniform_(new_weight[5120:, :])
|
|
||||||
nn.init.zeros_(new_bias[5120:])
|
|
||||||
|
|
||||||
state = {
|
|
||||||
"weight": new_weight,
|
|
||||||
"bias": new_bias
|
|
||||||
}
|
|
||||||
self.projection.load_state_dict(state)
|
|
||||||
logger.info(f"Loaded extended LLaVA weights: {5120}D -> {output_dim}D")
|
|
||||||
else:
|
|
||||||
logger.warning("LLaVA projector weights not found, using random initialization")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load LLaVA weights: {e}, using random initialization")
|
|
||||||
|
|
||||||
def image_to_embedding(self, image_input) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Convert image(s) to embeddings - keeping your exact original logic!
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_input: PIL.Image or list of PIL.Images
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Embeddings of shape (proj_out_dim,) for single image
|
|
||||||
or (B, proj_out_dim) for batch of images
|
|
||||||
"""
|
|
||||||
# Handle single image vs batch (same as your original)
|
|
||||||
if isinstance(image_input, Image.Image):
|
|
||||||
images = [image_input]
|
|
||||||
single_image = True
|
|
||||||
else:
|
|
||||||
images = image_input
|
|
||||||
single_image = False
|
|
||||||
|
|
||||||
# Preprocess images (same as your original)
|
|
||||||
inputs = self.processor(images=images, return_tensors="pt")
|
|
||||||
pixel_values = inputs.pixel_values.to(DEVICE) # (B, 3, H, W)
|
|
||||||
|
|
||||||
# Forward pass (exactly like your original)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out = self.vision_model(pixel_values=pixel_values)
|
out = self.vision_model(pixel_values=pixel_values)
|
||||||
feat = out.pooler_output # (B, 1024)
|
feat = out.pooler_output # (1,1024)
|
||||||
emb = self.projection(feat) # (B, proj_out_dim)
|
emb = self.projection(feat) # (1,proj_out_dim)
|
||||||
|
|
||||||
# Return appropriate shape (same as your original)
|
# Returns a 1D tensor of size [proj_out_dim] on DEVICE.
|
||||||
if single_image:
|
return emb.squeeze(0) # → (proj_out_dim,)
|
||||||
return emb.squeeze(0) # (proj_out_dim,)
|
|
||||||
else:
|
|
||||||
return emb # (B, proj_out_dim)
|
|
||||||
|
|
||||||
def get_memory_usage(self):
|
|
||||||
"""Get current GPU memory usage if available"""
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return {
|
|
||||||
'allocated': torch.cuda.memory_allocated() / 1024 ** 3, # GB
|
|
||||||
'reserved': torch.cuda.memory_reserved() / 1024 ** 3, # GB
|
|
||||||
}
|
|
||||||
return {'allocated': 0, 'reserved': 0}
|
|
||||||
|
|
||||||
|
|
||||||
# Test function
|
|
||||||
def main():
|
|
||||||
"""Simple test of the embedder"""
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Test ImageEmbedder")
|
|
||||||
parser.add_argument("--image", type=str, help="Path to test image")
|
|
||||||
parser.add_argument("--model", type=str, default="openai/clip-vit-large-patch14-336")
|
|
||||||
parser.add_argument("--dim", type=int, default=4096)
|
|
||||||
parser.add_argument("--llava-ckpt", type=str, help="Path to LLaVA checkpoint")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
embedder = ImageEmbedder(
|
|
||||||
vision_model_name=args.model,
|
|
||||||
proj_out_dim=args.dim,
|
|
||||||
llava_ckpt_path=args.llava_ckpt
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.image:
|
|
||||||
img = Image.open(args.image).convert("RGB")
|
|
||||||
emb = embedder.image_to_embedding(img)
|
|
||||||
print(f"Embedding shape: {emb.shape}")
|
|
||||||
print(f"Sample values: {emb[:10]}")
|
|
||||||
|
|
||||||
# Show memory usage
|
|
||||||
mem = embedder.get_memory_usage()
|
|
||||||
logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|||||||
@ -1,444 +1,36 @@
|
|||||||
import argparse
|
|
||||||
import pickle
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pickle
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
from embedder import ImageEmbedder
|
from embedder import ImageEmbedder
|
||||||
|
|
||||||
# Configure logging
|
# ——— Load stored embeddings & mapping ———
|
||||||
logging.basicConfig(level=logging.INFO)
|
vecs = np.load("imgVecs/image_vectors.npy") # (N, D)
|
||||||
logger = logging.getLogger(__name__)
|
with open("imgVecs/index_to_file.pkl", "rb") as f:
|
||||||
|
idx2file = pickle.load(f) # dict: idx → filepath
|
||||||
|
|
||||||
# Default paths
|
# ——— Specify query image ———
|
||||||
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
|
query_path = "datasets/coco/val2017/1140002154.jpg"
|
||||||
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
|
|
||||||
DEFAULT_TOP_K = 10
|
|
||||||
|
|
||||||
|
# ——— Embed the query image ———
|
||||||
class ImageSimilaritySearcher:
|
embedder = ImageEmbedder(
|
||||||
"""
|
vision_model_name="openai/clip-vit-large-patch14-336",
|
||||||
Image-to-image similarity search using unified embeddings dictionary
|
proj_out_dim=5120,
|
||||||
"""
|
llava_ckpt_path="datasets/pytorch_model-00003-of-00003.bin"
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image_embeddings_path: Path,
|
|
||||||
vision_model_name: str = DEFAULT_VISION_MODEL,
|
|
||||||
embedding_dim: int = 4096,
|
|
||||||
llava_ckpt_path: Optional[str] = None
|
|
||||||
):
|
|
||||||
self.image_embeddings_path = Path(image_embeddings_path)
|
|
||||||
self.embedding_dim = embedding_dim
|
|
||||||
|
|
||||||
# Load image embeddings
|
|
||||||
logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
|
|
||||||
self.image_data = self._load_image_embeddings()
|
|
||||||
|
|
||||||
# Initialize image embedder for query images
|
|
||||||
logger.info(f"Initializing image embedder: {vision_model_name}")
|
|
||||||
self.image_embedder = ImageEmbedder(
|
|
||||||
vision_model_name=vision_model_name,
|
|
||||||
proj_out_dim=embedding_dim,
|
|
||||||
llava_ckpt_path=llava_ckpt_path
|
|
||||||
)
|
)
|
||||||
|
img = Image.open(query_path).convert("RGB")
|
||||||
logger.info(f"ImageSimilaritySearcher ready with {len(self.image_data)} images")
|
q_vec = embedder.image_to_embedding(img).cpu().numpy() # (D,)
|
||||||
|
|
||||||
def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
|
# ——— Compute similarities & retrieve top-k ———
|
||||||
"""Load image embeddings from unified pickle file"""
|
# (you can normalize if you built your DB with inner-product indexing)
|
||||||
try:
|
# vecs_norm = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||||
with open(self.image_embeddings_path, "rb") as f:
|
# q_vec_norm = q_vec / np.linalg.norm(q_vec)
|
||||||
data = pickle.load(f)
|
# sims = cosine_similarity(q_vec_norm.reshape(1, -1), vecs_norm).flatten()
|
||||||
|
sims = cosine_similarity(q_vec.reshape(1, -1), vecs).flatten()
|
||||||
if not isinstance(data, dict):
|
top5 = sims.argsort()[-5:][::-1]
|
||||||
raise ValueError("Image embeddings file should contain a dictionary")
|
|
||||||
|
# ——— Print out the results ———
|
||||||
# Verify embedding dimensions
|
print(f"Query image: {query_path}\n")
|
||||||
if data:
|
for rank, idx in enumerate(top5, 1):
|
||||||
sample_key = next(iter(data))
|
print(f"{rank:>2}. {idx2file[idx]} (score: {sims[idx]:.4f})")
|
||||||
sample_emb = data[sample_key]
|
|
||||||
if sample_emb.shape[-1] != self.embedding_dim:
|
|
||||||
raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D")
|
|
||||||
|
|
||||||
logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
|
|
||||||
else:
|
|
||||||
logger.warning("Empty embeddings dictionary loaded")
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to load image embeddings: {e}")
|
|
||||||
|
|
||||||
def search_by_image_path(
|
|
||||||
self,
|
|
||||||
query_image_path: str,
|
|
||||||
top_k: int = DEFAULT_TOP_K,
|
|
||||||
exclude_self: bool = True,
|
|
||||||
return_scores: bool = False
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Search for similar images given a query image path
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_image_path: Path to query image
|
|
||||||
top_k: Number of top results to return
|
|
||||||
exclude_self: Whether to exclude the query image from results
|
|
||||||
return_scores: Whether to include similarity scores
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dictionaries with 'path' and optionally 'score' keys
|
|
||||||
"""
|
|
||||||
logger.info(f"Searching for images similar to: {query_image_path}")
|
|
||||||
|
|
||||||
# Load and embed query image
|
|
||||||
try:
|
|
||||||
query_img = Image.open(query_image_path).convert("RGB")
|
|
||||||
query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy()
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to process query image {query_image_path}: {e}")
|
|
||||||
|
|
||||||
return self._search_by_embedding(
|
|
||||||
query_embedding,
|
|
||||||
query_image_path if exclude_self else None,
|
|
||||||
top_k,
|
|
||||||
return_scores
|
|
||||||
)
|
|
||||||
|
|
||||||
def search_by_image(
|
|
||||||
self,
|
|
||||||
query_image: Image.Image,
|
|
||||||
top_k: int = DEFAULT_TOP_K,
|
|
||||||
return_scores: bool = False
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Search for similar images given a PIL Image
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_image: PIL Image object
|
|
||||||
top_k: Number of top results to return
|
|
||||||
return_scores: Whether to include similarity scores
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dictionaries with 'path' and optionally 'score' keys
|
|
||||||
"""
|
|
||||||
logger.info("Searching for similar images using provided image")
|
|
||||||
|
|
||||||
# Embed query image
|
|
||||||
query_embedding = self.image_embedder.image_to_embedding(query_image).cpu().numpy()
|
|
||||||
|
|
||||||
return self._search_by_embedding(query_embedding, None, top_k, return_scores)
|
|
||||||
|
|
||||||
def _search_by_embedding(
|
|
||||||
self,
|
|
||||||
query_embedding: np.ndarray,
|
|
||||||
exclude_path: Optional[str],
|
|
||||||
top_k: int,
|
|
||||||
return_scores: bool
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""Internal method to search by embedding vector"""
|
|
||||||
|
|
||||||
# Prepare database embeddings
|
|
||||||
image_paths = list(self.image_data.keys())
|
|
||||||
image_embeddings = np.stack(list(self.image_data.values())) # (N, D)
|
|
||||||
|
|
||||||
# Compute cosine similarities
|
|
||||||
similarities = cosine_similarity(
|
|
||||||
query_embedding.reshape(1, -1),
|
|
||||||
image_embeddings
|
|
||||||
).flatten()
|
|
||||||
|
|
||||||
# Handle exclusion
|
|
||||||
valid_indices = list(range(len(image_paths)))
|
|
||||||
if exclude_path:
|
|
||||||
try:
|
|
||||||
exclude_idx = image_paths.index(exclude_path)
|
|
||||||
valid_indices.remove(exclude_idx)
|
|
||||||
similarities = np.delete(similarities, exclude_idx)
|
|
||||||
image_paths = [p for i, p in enumerate(image_paths) if i != exclude_idx]
|
|
||||||
except ValueError:
|
|
||||||
pass # Query path not in database
|
|
||||||
|
|
||||||
# Get top-k indices
|
|
||||||
top_k_indices = np.argsort(-similarities)[:top_k]
|
|
||||||
|
|
||||||
# Prepare results
|
|
||||||
results = []
|
|
||||||
for idx in top_k_indices:
|
|
||||||
result = {'path': image_paths[idx]}
|
|
||||||
if return_scores:
|
|
||||||
result['score'] = float(similarities[idx])
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def batch_search(
|
|
||||||
self,
|
|
||||||
query_image_paths: List[str],
|
|
||||||
top_k: int = DEFAULT_TOP_K,
|
|
||||||
exclude_self: bool = True
|
|
||||||
) -> Dict[str, List[Dict]]:
|
|
||||||
"""
|
|
||||||
Search for multiple query images at once
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_image_paths: List of paths to query images
|
|
||||||
top_k: Number of results per query
|
|
||||||
exclude_self: Whether to exclude query images from results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping query path to list of results
|
|
||||||
"""
|
|
||||||
logger.info(f"Batch searching {len(query_image_paths)} images")
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
for query_path in query_image_paths:
|
|
||||||
try:
|
|
||||||
results[query_path] = self.search_by_image_path(
|
|
||||||
query_path,
|
|
||||||
top_k=top_k,
|
|
||||||
exclude_self=exclude_self,
|
|
||||||
return_scores=True
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to search for {query_path}: {e}")
|
|
||||||
results[query_path] = []
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def get_similarity_stats(self, query_image_path: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get statistics about similarity scores for a query image
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_image_path: Path to query image
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with similarity statistics
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
query_img = Image.open(query_image_path).convert("RGB")
|
|
||||||
query_embedding = self.image_embedder.image_to_embedding(query_img).cpu().numpy()
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to process query image: {e}")
|
|
||||||
|
|
||||||
image_embeddings = np.stack(list(self.image_data.values()))
|
|
||||||
similarities = cosine_similarity(
|
|
||||||
query_embedding.reshape(1, -1),
|
|
||||||
image_embeddings
|
|
||||||
).flatten()
|
|
||||||
|
|
||||||
return {
|
|
||||||
'mean_similarity': float(np.mean(similarities)),
|
|
||||||
'max_similarity': float(np.max(similarities)),
|
|
||||||
'min_similarity': float(np.min(similarities)),
|
|
||||||
'std_similarity': float(np.std(similarities)),
|
|
||||||
'median_similarity': float(np.median(similarities))
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_memory_usage(self):
|
|
||||||
"""Get memory usage information"""
|
|
||||||
return self.image_embedder.get_memory_usage()
|
|
||||||
|
|
||||||
|
|
||||||
def print_results(results: List[Dict], query: str, show_scores: bool = True):
|
|
||||||
"""Pretty print search results"""
|
|
||||||
print(f"\nTop {len(results)} similar images to: {Path(query).name}")
|
|
||||||
print(f"Query: {query}")
|
|
||||||
print("-" * 100)
|
|
||||||
|
|
||||||
for i, result in enumerate(results, 1):
|
|
||||||
path = result['path']
|
|
||||||
if show_scores and 'score' in result:
|
|
||||||
score = result['score']
|
|
||||||
print(f"{i:2d}. {path:<80} score={score:>7.4f}")
|
|
||||||
else:
|
|
||||||
print(f"{i:2d}. {path}")
|
|
||||||
|
|
||||||
print("-" * 100)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Image-to-Image Similarity Search")
|
|
||||||
|
|
||||||
# Input/Output paths
|
|
||||||
parser.add_argument(
|
|
||||||
"--image-embeddings",
|
|
||||||
type=str,
|
|
||||||
default=str(DEFAULT_IMG_EMB_PATH),
|
|
||||||
help="Path to image embeddings pickle file"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query options
|
|
||||||
parser.add_argument(
|
|
||||||
"--query-image",
|
|
||||||
type=str,
|
|
||||||
help="Path to query image"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--query-images-file",
|
|
||||||
type=str,
|
|
||||||
help="File with one image path per line"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--query-dir",
|
|
||||||
type=str,
|
|
||||||
help="Directory containing query images"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Model parameters
|
|
||||||
parser.add_argument(
|
|
||||||
"--vision-model",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_VISION_MODEL,
|
|
||||||
help="CLIP vision model to use"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--embedding-dim",
|
|
||||||
type=int,
|
|
||||||
default=4096,
|
|
||||||
help="Embedding dimension"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--llava-ckpt",
|
|
||||||
type=str,
|
|
||||||
help="Path to LLaVA checkpoint for projection weights"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search parameters
|
|
||||||
parser.add_argument(
|
|
||||||
"--top-k",
|
|
||||||
type=int,
|
|
||||||
default=DEFAULT_TOP_K,
|
|
||||||
help="Number of top results to return"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--include-self",
|
|
||||||
action="store_true",
|
|
||||||
help="Include query image in results"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Output options
|
|
||||||
parser.add_argument(
|
|
||||||
"--output",
|
|
||||||
type=str,
|
|
||||||
help="Save results to JSON file"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--stats",
|
|
||||||
action="store_true",
|
|
||||||
help="Show similarity statistics"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Initialize searcher
|
|
||||||
try:
|
|
||||||
searcher = ImageSimilaritySearcher(
|
|
||||||
image_embeddings_path=args.image_embeddings,
|
|
||||||
vision_model_name=args.vision_model,
|
|
||||||
embedding_dim=args.embedding_dim,
|
|
||||||
llava_ckpt_path=args.llava_ckpt
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize searcher: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Show memory usage
|
|
||||||
mem = searcher.get_memory_usage()
|
|
||||||
logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
|
|
||||||
|
|
||||||
# Handle different query modes
|
|
||||||
if args.query_image:
|
|
||||||
# Single query image
|
|
||||||
results = searcher.search_by_image_path(
|
|
||||||
args.query_image,
|
|
||||||
top_k=args.top_k,
|
|
||||||
exclude_self=not args.include_self,
|
|
||||||
return_scores=True
|
|
||||||
)
|
|
||||||
print_results(results, args.query_image)
|
|
||||||
|
|
||||||
if args.stats:
|
|
||||||
stats = searcher.get_similarity_stats(args.query_image)
|
|
||||||
print(f"\nSimilarity Statistics:")
|
|
||||||
for key, value in stats.items():
|
|
||||||
print(f" {key}: {value:.4f}")
|
|
||||||
|
|
||||||
elif args.query_images_file:
|
|
||||||
# Batch queries from file
|
|
||||||
try:
|
|
||||||
with open(args.query_images_file, 'r', encoding='utf-8') as f:
|
|
||||||
query_paths = [line.strip() for line in f if line.strip()]
|
|
||||||
|
|
||||||
batch_results = searcher.batch_search(
|
|
||||||
query_paths,
|
|
||||||
top_k=args.top_k,
|
|
||||||
exclude_self=not args.include_self
|
|
||||||
)
|
|
||||||
|
|
||||||
for query_path, results in batch_results.items():
|
|
||||||
if results: # Only show if we got results
|
|
||||||
print_results(results, query_path)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to process queries file: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
elif args.query_dir:
|
|
||||||
# Query all images in directory
|
|
||||||
query_dir = Path(args.query_dir)
|
|
||||||
if not query_dir.exists():
|
|
||||||
logger.error(f"Query directory does not exist: {query_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Find images in directory
|
|
||||||
image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tiff', '.tif'}
|
|
||||||
query_paths = []
|
|
||||||
for ext in image_exts:
|
|
||||||
query_paths.extend(query_dir.glob(f"*{ext}"))
|
|
||||||
query_paths.extend(query_dir.glob(f"*{ext.upper()}"))
|
|
||||||
|
|
||||||
query_paths = [str(p) for p in sorted(query_paths)]
|
|
||||||
|
|
||||||
if not query_paths:
|
|
||||||
logger.error(f"No images found in {query_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"Found {len(query_paths)} images in {query_dir}")
|
|
||||||
|
|
||||||
batch_results = searcher.batch_search(
|
|
||||||
query_paths,
|
|
||||||
top_k=args.top_k,
|
|
||||||
exclude_self=not args.include_self
|
|
||||||
)
|
|
||||||
|
|
||||||
for query_path, results in batch_results.items():
|
|
||||||
if results: # Only show if we got results
|
|
||||||
print_results(results, query_path)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print("Please specify a query image, query images file, or query directory")
|
|
||||||
print("Use --help for more information")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Save results if requested
|
|
||||||
if args.output and 'batch_results' in locals():
|
|
||||||
import json
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(args.output, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(batch_results, f, indent=2, ensure_ascii=False)
|
|
||||||
logger.info(f"Results saved to {args.output}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to save results: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|||||||
238
starter.py
238
starter.py
@ -1,231 +1,101 @@
|
|||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import pickle
|
import pickle
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
|
||||||
from atomicwrites import atomic_write
|
|
||||||
|
|
||||||
from embedder import ImageEmbedder, DEVICE
|
from embedder import ImageEmbedder, DEVICE
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Default paths
|
|
||||||
DEFAULT_OUTPUT_DIR = Path("embeddings")
|
|
||||||
DEFAULT_EMB_FILE = "image_embeddings.pkl"
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(
|
p = argparse.ArgumentParser(
|
||||||
description="Batch-convert images into 4096D embeddings stored in unified dictionary"
|
description="Batch-convert a folder of images into embedding vectors"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
p.add_argument(
|
||||||
"--image-dir", required=True,
|
"--image-dir", required=True,
|
||||||
help="Path to folder containing images (jpg/png/bmp/gif/webp)"
|
help="Path to a folder containing images (jpg/png/bmp/gif)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
p.add_argument(
|
||||||
"--output-dir", default=str(DEFAULT_OUTPUT_DIR),
|
"--output-dir", default="processed_images",
|
||||||
help="Directory to save embeddings file"
|
help="Where to save image_vectors.npy and index_to_file.pkl"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
p.add_argument(
|
||||||
"--output-file", default=DEFAULT_EMB_FILE,
|
|
||||||
help="Filename for embeddings dictionary (pkl format)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--vision-model", default="openai/clip-vit-large-patch14-336",
|
"--vision-model", default="openai/clip-vit-large-patch14-336",
|
||||||
help="Hugging Face CLIP vision model name"
|
help="Hugging Face name of the CLIP vision encoder"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
p.add_argument(
|
||||||
"--embedding-dim", type=int, default=4096,
|
"--proj-dim", type=int, default=5120,
|
||||||
help="Output embedding dimension"
|
help="Dimensionality of the projection output"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
p.add_argument(
|
||||||
"--llava-ckpt", default=None,
|
"--llava-ckpt", default=None,
|
||||||
help="(Optional) LLaVA checkpoint to load projection weights from"
|
help="(Optional) full LLaVA checkpoint .bin to load projector weights from"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
p.add_argument(
|
||||||
"--batch-size", type=int, default=32,
|
"--batch-size", type=int, default=64,
|
||||||
help="Batch size for processing images"
|
help="How many images to encode per GPU/CPU batch"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
return p.parse_args()
|
||||||
"--resume", action="store_true",
|
|
||||||
help="Resume from existing embeddings file (skip already processed images)"
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def find_images(folder: str) -> list:
|
def find_images(folder):
|
||||||
"""Find all supported image files in folder (non-recursive by default)"""
|
exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif")
|
||||||
supported_exts = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tiff", ".tif")
|
return sorted([
|
||||||
image_paths = []
|
os.path.join(folder, fname)
|
||||||
|
for fname in os.listdir(folder)
|
||||||
folder_path = Path(folder)
|
if fname.lower().endswith(exts)
|
||||||
|
])
|
||||||
# Use a set to avoid duplicates
|
|
||||||
found_files = set()
|
|
||||||
|
|
||||||
# Search for files with supported extensions (case-insensitive)
|
|
||||||
for file_path in folder_path.iterdir():
|
|
||||||
if file_path.is_file():
|
|
||||||
file_ext = file_path.suffix.lower()
|
|
||||||
if file_ext in supported_exts:
|
|
||||||
found_files.add(str(file_path.resolve()))
|
|
||||||
|
|
||||||
return sorted(list(found_files))
|
|
||||||
|
|
||||||
|
|
||||||
def load_existing_embeddings(filepath: Path) -> Dict[str, np.ndarray]:
|
|
||||||
"""Load existing embeddings dictionary"""
|
|
||||||
if not filepath.exists():
|
|
||||||
return {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(filepath, "rb") as f:
|
|
||||||
embeddings = pickle.load(f)
|
|
||||||
logger.info(f"Loaded {len(embeddings)} existing embeddings from {filepath}")
|
|
||||||
return embeddings
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load existing embeddings: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def save_embeddings(embeddings: Dict[str, np.ndarray], filepath: Path):
|
|
||||||
"""Safely save embeddings dictionary"""
|
|
||||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
with atomic_write(filepath, mode='wb', overwrite=True) as f:
|
|
||||||
pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
|
|
||||||
logger.info(f"Saved {len(embeddings)} embeddings to {filepath}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
# Setup paths
|
|
||||||
output_dir = Path(args.output_dir)
|
|
||||||
output_file = output_dir / args.output_file
|
|
||||||
|
|
||||||
# Discover images
|
# Discover images
|
||||||
logger.info(f"Scanning for images in {args.image_dir}")
|
|
||||||
image_paths = find_images(args.image_dir)
|
image_paths = find_images(args.image_dir)
|
||||||
|
|
||||||
if not image_paths:
|
if not image_paths:
|
||||||
raise RuntimeError(f"No supported images found in {args.image_dir}")
|
raise RuntimeError(f"No images found in {args.image_dir}")
|
||||||
|
print(f"Found {len(image_paths)} images in {args.image_dir}")
|
||||||
|
|
||||||
logger.info(f"Found {len(image_paths)} images")
|
# Build embedder
|
||||||
|
|
||||||
# Load existing embeddings if resuming
|
|
||||||
embeddings_dict = {}
|
|
||||||
if args.resume:
|
|
||||||
embeddings_dict = load_existing_embeddings(output_file)
|
|
||||||
|
|
||||||
# Filter out already processed images
|
|
||||||
new_image_paths = [p for p in image_paths if p not in embeddings_dict]
|
|
||||||
logger.info(f"Resuming: {len(embeddings_dict)} already processed, {len(new_image_paths)} remaining")
|
|
||||||
image_paths = new_image_paths
|
|
||||||
|
|
||||||
if not image_paths:
|
|
||||||
logger.info("All images already processed!")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Initialize embedder
|
|
||||||
logger.info(f"Initializing ImageEmbedder on {DEVICE}")
|
|
||||||
embedder = ImageEmbedder(
|
embedder = ImageEmbedder(
|
||||||
vision_model_name=args.vision_model,
|
vision_model_name=args.vision_model,
|
||||||
proj_out_dim=args.embedding_dim,
|
proj_out_dim=args.proj_dim,
|
||||||
llava_ckpt_path=args.llava_ckpt
|
llava_ckpt_path=args.llava_ckpt
|
||||||
)
|
)
|
||||||
|
print(f"Using device: {DEVICE}")
|
||||||
|
|
||||||
# Show memory usage
|
# Process in batches
|
||||||
mem = embedder.get_memory_usage()
|
all_embs = []
|
||||||
logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
|
index_to_file = {}
|
||||||
|
|
||||||
# Process images in batches
|
|
||||||
processed_count = 0
|
|
||||||
failed_count = 0
|
|
||||||
|
|
||||||
progress_bar = tqdm(total=len(image_paths), desc="Processing images")
|
|
||||||
|
|
||||||
for batch_start in range(0, len(image_paths), args.batch_size):
|
for batch_start in range(0, len(image_paths), args.batch_size):
|
||||||
batch_end = min(batch_start + args.batch_size, len(image_paths))
|
batch_paths = image_paths[batch_start:batch_start + args.batch_size]
|
||||||
batch_paths = image_paths[batch_start:batch_end]
|
# load images
|
||||||
|
imgs = [Image.open(p).convert("RGB") for p in batch_paths]
|
||||||
# Load batch images
|
# embed
|
||||||
batch_images = []
|
|
||||||
valid_paths = []
|
|
||||||
|
|
||||||
for img_path in batch_paths:
|
|
||||||
try:
|
|
||||||
img = Image.open(img_path).convert("RGB")
|
|
||||||
batch_images.append(img)
|
|
||||||
valid_paths.append(img_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load {img_path}: {e}")
|
|
||||||
failed_count += 1
|
|
||||||
progress_bar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not batch_images:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Generate embeddings for batch
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if len(batch_images) == 1:
|
embs = embedder.image_to_embedding(imgs) # (B, D)
|
||||||
embeddings = embedder.image_to_embedding(batch_images[0]).cpu().numpy()
|
embs_np = embs.cpu().numpy()
|
||||||
embeddings = embeddings.reshape(1, -1) # Make it (1, D)
|
all_embs.append(embs_np)
|
||||||
else:
|
# record mapping
|
||||||
embeddings = embedder.image_to_embedding(batch_images).cpu().numpy() # (B, D)
|
for i, p in enumerate(batch_paths):
|
||||||
|
index_to_file[batch_start + i] = p
|
||||||
|
|
||||||
# Store embeddings in dictionary
|
print(f" • Processed {batch_start + len(batch_paths)}/{len(image_paths)} images")
|
||||||
for img_path, embedding in zip(valid_paths, embeddings):
|
|
||||||
embeddings_dict[img_path] = embedding
|
|
||||||
processed_count += 1
|
|
||||||
|
|
||||||
progress_bar.update(len(valid_paths))
|
# Stack and save
|
||||||
|
vectors = np.vstack(all_embs) # shape (N, D)
|
||||||
|
vec_file = os.path.join(args.output_dir, "image_vectors.npy")
|
||||||
|
map_file = os.path.join(args.output_dir, "index_to_file.pkl")
|
||||||
|
|
||||||
# Periodic save every 10 batches
|
np.save(vec_file, vectors)
|
||||||
if (batch_start // args.batch_size) % 10 == 0:
|
with open(map_file, "wb") as f:
|
||||||
save_embeddings(embeddings_dict, output_file)
|
pickle.dump(index_to_file, f)
|
||||||
|
|
||||||
except Exception as e:
|
print(f"\nSaved {vectors.shape[0]}×{vectors.shape[1]} vectors to\n {vec_file}")
|
||||||
logger.error(f"Failed to process batch starting at {batch_start}: {e}")
|
print(f"Saved index→file mapping to\n {map_file}")
|
||||||
failed_count += len(batch_paths)
|
|
||||||
progress_bar.update(len(batch_paths))
|
|
||||||
continue
|
|
||||||
|
|
||||||
progress_bar.close()
|
|
||||||
|
|
||||||
# Final save
|
|
||||||
save_embeddings(embeddings_dict, output_file)
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
total_processed = len(embeddings_dict)
|
|
||||||
logger.info(f"\nProcessing complete!")
|
|
||||||
logger.info(f"Total embeddings: {total_processed}")
|
|
||||||
logger.info(f"Newly processed: {processed_count}")
|
|
||||||
logger.info(f"Failed: {failed_count}")
|
|
||||||
logger.info(f"Embedding dimension: {args.embedding_dim}")
|
|
||||||
logger.info(f"Output file: {output_file}")
|
|
||||||
|
|
||||||
# Verify a sample embedding
|
|
||||||
if embeddings_dict:
|
|
||||||
sample_path = next(iter(embeddings_dict))
|
|
||||||
sample_emb = embeddings_dict[sample_path]
|
|
||||||
logger.info(f"Sample embedding shape: {sample_emb.shape}")
|
|
||||||
logger.info(f"Sample embedding norm: {np.linalg.norm(sample_emb):.4f}")
|
|
||||||
|
|
||||||
# Final memory usage
|
|
||||||
mem = embedder.get_memory_usage()
|
|
||||||
logger.info(f"Final GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
306
text_embedder.py
306
text_embedder.py
@ -1,306 +0,0 @@
|
|||||||
import pickle
|
|
||||||
import hashlib
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
|
||||||
from cachetools import LRUCache
|
|
||||||
from tqdm import tqdm
|
|
||||||
from atomicwrites import atomic_write
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Device & defaults
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
# Use CLIP which is specifically designed for text-image alignment
|
|
||||||
DEFAULT_MODEL = "openai/clip-vit-base-patch32"
|
|
||||||
OUTPUT_DIM = 4096 # Match image embeddings dimension
|
|
||||||
CHUNK_BATCH_SIZE = 16
|
|
||||||
ARTICLE_CACHE_SIZE = 1024
|
|
||||||
|
|
||||||
# Persistence paths
|
|
||||||
EMBEDDINGS_DIR = Path(__file__).parent / "embeddings"
|
|
||||||
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEmbedder:
|
|
||||||
"""
|
|
||||||
Text embedder using CLIP's text encoder, optimized for text-to-image
|
|
||||||
retrieval with 4096-dimensional output to match image embeddings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str = DEFAULT_MODEL,
|
|
||||||
output_dim: int = OUTPUT_DIM,
|
|
||||||
chunk_batch_size: int = CHUNK_BATCH_SIZE,
|
|
||||||
):
|
|
||||||
logger.info(f"Initializing CLIPTextEmbedder with {model_name}")
|
|
||||||
|
|
||||||
# Load CLIP text encoder
|
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
|
|
||||||
self.text_encoder = CLIPTextModel.from_pretrained(model_name).to(DEVICE).eval()
|
|
||||||
|
|
||||||
# Freeze the text encoder
|
|
||||||
for p in self.text_encoder.parameters():
|
|
||||||
p.requires_grad = False
|
|
||||||
|
|
||||||
# Get CLIP's text embedding dimension (usually 512)
|
|
||||||
clip_dim = self.text_encoder.config.hidden_size
|
|
||||||
|
|
||||||
# Enhanced projector to match 4096-D image embeddings
|
|
||||||
self.projector = nn.Sequential(
|
|
||||||
nn.Linear(clip_dim, output_dim // 2),
|
|
||||||
nn.LayerNorm(output_dim // 2),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Dropout(0.1),
|
|
||||||
nn.Linear(output_dim // 2, output_dim),
|
|
||||||
nn.LayerNorm(output_dim),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(output_dim, output_dim),
|
|
||||||
).to(DEVICE)
|
|
||||||
|
|
||||||
# Initialize projector with Xavier initialization
|
|
||||||
for layer in self.projector:
|
|
||||||
if isinstance(layer, nn.Linear):
|
|
||||||
nn.init.xavier_uniform_(layer.weight)
|
|
||||||
nn.init.zeros_(layer.bias)
|
|
||||||
|
|
||||||
self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)
|
|
||||||
self.chunk_batch_size = chunk_batch_size
|
|
||||||
|
|
||||||
# Ensure embeddings dir
|
|
||||||
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
logger.info(f"CLIPTextEmbedder ready: {clip_dim}D -> {output_dim}D on {DEVICE}")
|
|
||||||
|
|
||||||
def text_to_embedding(self, text: Union[str, List[str]]) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Embed text using CLIP's text encoder + projector
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Single string or list of strings
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Normalized embeddings of shape (D,) for single text
|
|
||||||
or (B, D) for batch of texts
|
|
||||||
"""
|
|
||||||
# Handle single text vs batch
|
|
||||||
if isinstance(text, str):
|
|
||||||
texts = [text]
|
|
||||||
single_text = True
|
|
||||||
else:
|
|
||||||
texts = text
|
|
||||||
single_text = False
|
|
||||||
|
|
||||||
# Check cache for single text
|
|
||||||
if single_text and texts[0] in self.article_cache:
|
|
||||||
return self.article_cache[texts[0]]
|
|
||||||
|
|
||||||
# Tokenize and encode with CLIP
|
|
||||||
inputs = self.tokenizer(
|
|
||||||
texts,
|
|
||||||
return_tensors="pt",
|
|
||||||
truncation=True,
|
|
||||||
padding=True,
|
|
||||||
max_length=77 # CLIP's max length
|
|
||||||
).to(DEVICE)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# Get CLIP text features
|
|
||||||
text_outputs = self.text_encoder(**inputs)
|
|
||||||
# Use pooled output (CLS token equivalent)
|
|
||||||
clip_features = text_outputs.pooler_output # [batch_size, clip_dim]
|
|
||||||
|
|
||||||
# Normalize CLIP features
|
|
||||||
clip_features = nn.functional.normalize(clip_features, p=2, dim=1)
|
|
||||||
|
|
||||||
# Project to target dimension
|
|
||||||
projected = self.projector(clip_features)
|
|
||||||
# Normalize final output
|
|
||||||
projected = nn.functional.normalize(projected, p=2, dim=1)
|
|
||||||
|
|
||||||
# Cache single text result
|
|
||||||
if single_text:
|
|
||||||
result = projected.squeeze(0).cpu()
|
|
||||||
self.article_cache[texts[0]] = result
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return projected.cpu() # Return batch
|
|
||||||
|
|
||||||
def embed_texts_from_file(
|
|
||||||
self,
|
|
||||||
text_file: Union[str, Path],
|
|
||||||
save_path: Union[str, Path] = TEXT_EMB_PATH
|
|
||||||
) -> Dict[str, np.ndarray]:
|
|
||||||
"""
|
|
||||||
Read newline-separated texts and embed them
|
|
||||||
"""
|
|
||||||
save_path = Path(save_path)
|
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Load existing embeddings
|
|
||||||
try:
|
|
||||||
with open(save_path, "rb") as fp:
|
|
||||||
mapping: Dict[str, np.ndarray] = pickle.load(fp)
|
|
||||||
logger.info(f"Loaded {len(mapping)} existing text embeddings")
|
|
||||||
except FileNotFoundError:
|
|
||||||
mapping = {}
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed loading {save_path}: {e}")
|
|
||||||
mapping = {}
|
|
||||||
|
|
||||||
# Read text file
|
|
||||||
try:
|
|
||||||
with open(text_file, "r", encoding="utf-8") as fp:
|
|
||||||
lines = [ln.strip() for ln in fp if ln.strip()]
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Error reading {text_file}: {e}")
|
|
||||||
|
|
||||||
logger.info(f"Processing {len(lines)} texts from {text_file}")
|
|
||||||
|
|
||||||
# Process in batches
|
|
||||||
new_texts = [ln for ln in lines if ln not in mapping]
|
|
||||||
logger.info(f"Embedding {len(new_texts)} new texts")
|
|
||||||
|
|
||||||
if not new_texts:
|
|
||||||
logger.info("All texts already embedded!")
|
|
||||||
return mapping
|
|
||||||
|
|
||||||
# Process in batches
|
|
||||||
for i in tqdm(range(0, len(new_texts), self.chunk_batch_size), desc="Embedding texts"):
|
|
||||||
batch_texts = new_texts[i:i + self.chunk_batch_size]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get batch embeddings
|
|
||||||
batch_embeddings = self.text_to_embedding(batch_texts)
|
|
||||||
|
|
||||||
# Store in mapping
|
|
||||||
for text, embedding in zip(batch_texts, batch_embeddings):
|
|
||||||
mapping[text] = embedding.numpy()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Embedding failed for batch starting at {i}: {e}")
|
|
||||||
# Process individually as fallback
|
|
||||||
for text in batch_texts:
|
|
||||||
try:
|
|
||||||
vec = self.text_to_embedding(text)
|
|
||||||
mapping[text] = vec.numpy()
|
|
||||||
except Exception as e2:
|
|
||||||
logger.error(f"Individual embedding failed for text: {e2}")
|
|
||||||
|
|
||||||
# Save updated mappings
|
|
||||||
with atomic_write(save_path, mode='wb', overwrite=True) as f:
|
|
||||||
pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
|
|
||||||
logger.info(f"Saved {len(mapping)} text embeddings to {save_path}")
|
|
||||||
return mapping
|
|
||||||
|
|
||||||
def embed_single_text(self, text: str) -> np.ndarray:
|
|
||||||
"""Embed a single text and return as numpy array"""
|
|
||||||
embedding = self.text_to_embedding(text)
|
|
||||||
return embedding.numpy()
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
"""Clear the text cache"""
|
|
||||||
self.article_cache.clear()
|
|
||||||
logger.info("Text cache cleared")
|
|
||||||
|
|
||||||
def get_memory_usage(self):
|
|
||||||
"""Get current GPU memory usage if available."""
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return {
|
|
||||||
'allocated': torch.cuda.memory_allocated() / 1024 ** 3, # GB
|
|
||||||
'reserved': torch.cuda.memory_reserved() / 1024 ** 3, # GB
|
|
||||||
}
|
|
||||||
return {'allocated': 0, 'reserved': 0}
|
|
||||||
|
|
||||||
|
|
||||||
# Alias for compatibility
|
|
||||||
TextEmbedder = CLIPTextEmbedder
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser("CLIP Text Embedder for Image Search")
|
|
||||||
parser.add_argument("--text", type=str, help="Single text to embed")
|
|
||||||
parser.add_argument("--file", type=str, help="Path to newline-delimited texts")
|
|
||||||
parser.add_argument(
|
|
||||||
"--out", type=str, default=str(TEXT_EMB_PATH),
|
|
||||||
help="Pickle path for saving embeddings"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch-size", type=int, default=CHUNK_BATCH_SIZE,
|
|
||||||
help="Batch size for processing"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model", type=str, default=DEFAULT_MODEL,
|
|
||||||
help="CLIP model name to use"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dim", type=int, default=OUTPUT_DIM,
|
|
||||||
help="Output embedding dimension"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
embedder = CLIPTextEmbedder(
|
|
||||||
model_name=args.model,
|
|
||||||
output_dim=args.dim,
|
|
||||||
chunk_batch_size=args.batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Show memory usage
|
|
||||||
mem = embedder.get_memory_usage()
|
|
||||||
logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
|
|
||||||
|
|
||||||
# Single-text mode
|
|
||||||
if args.text:
|
|
||||||
vec = embedder.embed_single_text(args.text)
|
|
||||||
print(f"Text: {args.text}")
|
|
||||||
print(f"Shape: {vec.shape}")
|
|
||||||
print(f"Norm: {np.linalg.norm(vec):.4f}")
|
|
||||||
print(f"Sample values: {vec[:10]}")
|
|
||||||
|
|
||||||
# Save to file
|
|
||||||
save_path = Path(args.out)
|
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
try:
|
|
||||||
with open(save_path, "rb") as fp:
|
|
||||||
mapping = pickle.load(fp)
|
|
||||||
except (FileNotFoundError, pickle.PickleError):
|
|
||||||
mapping = {}
|
|
||||||
|
|
||||||
mapping[args.text] = vec
|
|
||||||
with atomic_write(save_path, mode='wb', overwrite=True) as f:
|
|
||||||
pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
logger.info(f"Saved text embedding to {save_path}")
|
|
||||||
|
|
||||||
# Batch-file mode
|
|
||||||
if args.file:
|
|
||||||
embedder.embed_texts_from_file(args.file, save_path=args.out)
|
|
||||||
logger.info(f"Completed batch embedding from {args.file}")
|
|
||||||
|
|
||||||
if not args.text and not args.file:
|
|
||||||
# Demo mode
|
|
||||||
demo_texts = [
|
|
||||||
"a cat sleeping on a bed",
|
|
||||||
"a dog running in the park",
|
|
||||||
"beautiful sunset over mountains",
|
|
||||||
"red sports car on highway"
|
|
||||||
]
|
|
||||||
|
|
||||||
print("Demo: Embedding sample texts...")
|
|
||||||
for text in demo_texts:
|
|
||||||
vec = embedder.embed_single_text(text)
|
|
||||||
print(f"'{text}' -> shape {vec.shape}, norm {np.linalg.norm(vec):.4f}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -1,356 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import pickle
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from text_embedder import CLIPTextEmbedder
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Default paths
|
|
||||||
DEFAULT_IMG_EMB_PATH = Path("embeddings/image_embeddings.pkl")
|
|
||||||
DEFAULT_TEXT_MODEL = "openai/clip-vit-base-patch32"
|
|
||||||
DEFAULT_TOP_K = 10
|
|
||||||
|
|
||||||
|
|
||||||
class TextToImageSearcher:
|
|
||||||
"""
|
|
||||||
Text-to-image search using CLIP embeddings
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image_embeddings_path: Path,
|
|
||||||
text_model_name: str = DEFAULT_TEXT_MODEL,
|
|
||||||
embedding_dim: int = 4096
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the Text-to-Image searcher
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_embeddings_path: Path to the pickle file containing image embeddings
|
|
||||||
text_model_name: Name of the CLIP text model to use
|
|
||||||
embedding_dim: Dimension of the embeddings (should match image embeddings)
|
|
||||||
"""
|
|
||||||
self.image_embeddings_path = Path(image_embeddings_path)
|
|
||||||
self.embedding_dim = embedding_dim
|
|
||||||
|
|
||||||
# Load image embeddings
|
|
||||||
logger.info(f"Loading image embeddings from {self.image_embeddings_path}")
|
|
||||||
self.image_data = self._load_image_embeddings()
|
|
||||||
|
|
||||||
# Initialize text embedder
|
|
||||||
logger.info(f"Initializing text embedder: {text_model_name}")
|
|
||||||
self.text_embedder = CLIPTextEmbedder(
|
|
||||||
model_name=text_model_name,
|
|
||||||
output_dim=embedding_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"TextToImageSearcher ready with {len(self.image_data)} images")
|
|
||||||
|
|
||||||
def _load_image_embeddings(self) -> Dict[str, np.ndarray]:
|
|
||||||
"""Load image embeddings from pickle file"""
|
|
||||||
try:
|
|
||||||
with open(self.image_embeddings_path, "rb") as f:
|
|
||||||
data = pickle.load(f)
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
raise ValueError("Image embeddings file should contain a dictionary")
|
|
||||||
|
|
||||||
# Verify embedding dimensions
|
|
||||||
sample_key = next(iter(data))
|
|
||||||
sample_emb = data[sample_key]
|
|
||||||
if sample_emb.shape[-1] != self.embedding_dim:
|
|
||||||
raise ValueError(f"Expected {self.embedding_dim}D embeddings, got {sample_emb.shape[-1]}D")
|
|
||||||
|
|
||||||
logger.info(f"Loaded {len(data)} image embeddings of dimension {sample_emb.shape[-1]}")
|
|
||||||
return data
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
raise FileNotFoundError(f"Image embeddings file not found: {self.image_embeddings_path}")
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to load image embeddings: {e}")
|
|
||||||
|
|
||||||
def search(
|
|
||||||
self,
|
|
||||||
query_text: str,
|
|
||||||
top_k: int = DEFAULT_TOP_K,
|
|
||||||
return_scores: bool = False
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Search for images matching the query text
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_text: Text description to search for
|
|
||||||
top_k: Number of top results to return
|
|
||||||
return_scores: Whether to include similarity scores
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dictionaries with 'path' and optionally 'score' keys
|
|
||||||
"""
|
|
||||||
logger.info(f"Searching for: '{query_text}'")
|
|
||||||
|
|
||||||
# Embed query text
|
|
||||||
query_embedding = self.text_embedder.embed_single_text(query_text)
|
|
||||||
|
|
||||||
# Prepare image data for comparison
|
|
||||||
image_paths = list(self.image_data.keys())
|
|
||||||
image_embeddings = np.stack(list(self.image_data.values())) # (N, D)
|
|
||||||
|
|
||||||
# Compute cosine similarities (both embeddings are normalized)
|
|
||||||
similarities = np.dot(image_embeddings, query_embedding)
|
|
||||||
|
|
||||||
# Get top-k indices
|
|
||||||
top_k_indices = np.argsort(-similarities)[:top_k]
|
|
||||||
|
|
||||||
# Prepare results
|
|
||||||
results = []
|
|
||||||
for idx in top_k_indices:
|
|
||||||
result = {'path': image_paths[idx]}
|
|
||||||
if return_scores:
|
|
||||||
result['score'] = float(similarities[idx])
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def batch_search(
|
|
||||||
self,
|
|
||||||
query_texts: List[str],
|
|
||||||
top_k: int = DEFAULT_TOP_K
|
|
||||||
) -> Dict[str, List[Dict]]:
|
|
||||||
"""
|
|
||||||
Search for multiple queries at once
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_texts: List of text descriptions
|
|
||||||
top_k: Number of results per query
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary mapping query text to list of results
|
|
||||||
"""
|
|
||||||
logger.info(f"Batch searching {len(query_texts)} queries")
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
for query in query_texts:
|
|
||||||
results[query] = self.search(query, top_k=top_k, return_scores=True)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def get_search_stats(self, query_text: str) -> Dict:
|
|
||||||
"""
|
|
||||||
Get statistics about search results
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_text: Query to analyze
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with similarity statistics
|
|
||||||
"""
|
|
||||||
query_embedding = self.text_embedder.embed_single_text(query_text)
|
|
||||||
image_embeddings = np.stack(list(self.image_data.values()))
|
|
||||||
similarities = np.dot(image_embeddings, query_embedding)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'mean_similarity': float(np.mean(similarities)),
|
|
||||||
'max_similarity': float(np.max(similarities)),
|
|
||||||
'min_similarity': float(np.min(similarities)),
|
|
||||||
'std_similarity': float(np.std(similarities)),
|
|
||||||
'median_similarity': float(np.median(similarities))
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_memory_usage(self):
|
|
||||||
"""Get memory usage information"""
|
|
||||||
return self.text_embedder.get_memory_usage()
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
"""Clear text embedding cache"""
|
|
||||||
self.text_embedder.clear_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def print_results(results: List[Dict], query: str, show_scores: bool = True):
|
|
||||||
"""Pretty print search results"""
|
|
||||||
print(f"\nTop {len(results)} results for: '{query}'")
|
|
||||||
print("-" * 80)
|
|
||||||
|
|
||||||
for i, result in enumerate(results, 1):
|
|
||||||
path = result['path']
|
|
||||||
if show_scores and 'score' in result:
|
|
||||||
score = result['score']
|
|
||||||
print(f"{i:2d}. {path:<60} score={score:>7.4f}")
|
|
||||||
else:
|
|
||||||
print(f"{i:2d}. {path}")
|
|
||||||
|
|
||||||
print("-" * 80)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Text-to-Image Search using CLIP embeddings")
|
|
||||||
|
|
||||||
# Input/Output paths
|
|
||||||
parser.add_argument(
|
|
||||||
"--image-embeddings",
|
|
||||||
type=str,
|
|
||||||
default=str(DEFAULT_IMG_EMB_PATH),
|
|
||||||
help="Path to image embeddings pickle file"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query options
|
|
||||||
parser.add_argument(
|
|
||||||
"--query",
|
|
||||||
type=str,
|
|
||||||
help="Single text query to search for"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--queries-file",
|
|
||||||
type=str,
|
|
||||||
help="File with one query per line"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--interactive",
|
|
||||||
action="store_true",
|
|
||||||
help="Interactive mode for multiple queries"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search parameters
|
|
||||||
parser.add_argument(
|
|
||||||
"--top-k",
|
|
||||||
type=int,
|
|
||||||
default=DEFAULT_TOP_K,
|
|
||||||
help="Number of top results to return"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--text-model",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_TEXT_MODEL,
|
|
||||||
help="CLIP text model to use"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--embedding-dim",
|
|
||||||
type=int,
|
|
||||||
default=4096,
|
|
||||||
help="Embedding dimension"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Output options
|
|
||||||
parser.add_argument(
|
|
||||||
"--output",
|
|
||||||
type=str,
|
|
||||||
help="Save results to JSON file"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--stats",
|
|
||||||
action="store_true",
|
|
||||||
help="Show similarity statistics"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Initialize searcher
|
|
||||||
try:
|
|
||||||
searcher = TextToImageSearcher(
|
|
||||||
image_embeddings_path=args.image_embeddings,
|
|
||||||
text_model_name=args.text_model,
|
|
||||||
embedding_dim=args.embedding_dim
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize searcher: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Show memory usage
|
|
||||||
mem = searcher.get_memory_usage()
|
|
||||||
logger.info(f"GPU Memory - Allocated: {mem['allocated']:.2f}GB, Reserved: {mem['reserved']:.2f}GB")
|
|
||||||
|
|
||||||
# Handle different query modes
|
|
||||||
if args.query:
|
|
||||||
# Single query
|
|
||||||
results = searcher.search(args.query, top_k=args.top_k, return_scores=True)
|
|
||||||
print_results(results, args.query)
|
|
||||||
|
|
||||||
if args.stats:
|
|
||||||
stats = searcher.get_search_stats(args.query)
|
|
||||||
print(f"\nSimilarity Statistics:")
|
|
||||||
for key, value in stats.items():
|
|
||||||
print(f" {key}: {value:.4f}")
|
|
||||||
|
|
||||||
elif args.queries_file:
|
|
||||||
# Batch queries from file
|
|
||||||
try:
|
|
||||||
with open(args.queries_file, 'r', encoding='utf-8') as f:
|
|
||||||
queries = [line.strip() for line in f if line.strip()]
|
|
||||||
|
|
||||||
batch_results = searcher.batch_search(queries, top_k=args.top_k)
|
|
||||||
|
|
||||||
for query, results in batch_results.items():
|
|
||||||
print_results(results, query)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to process queries file: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
elif args.interactive:
|
|
||||||
# Interactive mode
|
|
||||||
print("\nInteractive Text-to-Image Search")
|
|
||||||
print("Enter text queries (or 'quit' to exit):")
|
|
||||||
print("-" * 40)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
query = input("\nQuery: ").strip()
|
|
||||||
if query.lower() in ['quit', 'exit', 'q']:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not query:
|
|
||||||
continue
|
|
||||||
|
|
||||||
results = searcher.search(query, top_k=args.top_k, return_scores=True)
|
|
||||||
print_results(results, query)
|
|
||||||
|
|
||||||
if args.stats:
|
|
||||||
stats = searcher.get_search_stats(query)
|
|
||||||
print(f"\nStats: mean={stats['mean_similarity']:.3f}, "
|
|
||||||
f"max={stats['max_similarity']:.3f}, "
|
|
||||||
f"std={stats['std_similarity']:.3f}")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nExiting...")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing query: {e}")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Demo mode with sample queries
|
|
||||||
demo_queries = [
|
|
||||||
"a cat sleeping on a bed",
|
|
||||||
]
|
|
||||||
|
|
||||||
print("Demo mode: Running sample queries...")
|
|
||||||
batch_results = searcher.batch_search(demo_queries, top_k=5)
|
|
||||||
|
|
||||||
for query, results in batch_results.items():
|
|
||||||
print_results(results, query)
|
|
||||||
|
|
||||||
# Save results if requested
|
|
||||||
if args.output and 'batch_results' in locals():
|
|
||||||
import json
|
|
||||||
|
|
||||||
# Convert results to JSON-serializable format
|
|
||||||
json_results = {}
|
|
||||||
for query, results in batch_results.items():
|
|
||||||
json_results[query] = results # Results are already dict format
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(args.output, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(json_results, f, indent=2, ensure_ascii=False)
|
|
||||||
logger.info(f"Results saved to {args.output}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to save results: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Loading…
x
Reference in New Issue
Block a user