Initial Commit

commit d173a94399
2
embed/__init__.py
Normal file
@@ -0,0 +1,2 @@
class ImageEmbedder:
    pass
258
embed/image_embedder.py
Normal file
@@ -0,0 +1,258 @@
import pickle
import logging
from pathlib import Path
from typing import Dict, List, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image, UnidentifiedImageError
from transformers import CLIPImageProcessor, CLIPVisionModel
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEFAULT_VISION_MODEL = "openai/clip-vit-large-patch14-336"
PROJECTOR_WEIGHTS_PATH = (
    Path(__file__).parent.parent
    / "models"
    / "llava-v1.5-mlp2x-336px-vicuna-7b"
    / "mm_projector.bin"
)
OUTPUT_DIM = 4096
DEFAULT_BATCH_SIZE = 32

# Default output path
EMBED_DIR = Path(__file__).parent.parent / "embeddings"
IMAGE_EMB_PATH = EMBED_DIR / "image_embeddings.pkl"


class ImageEmbedder:
    """
    Frozen CLIP vision encoder + pretrained 2-layer MLP projector.

    Uses PATCH-LEVEL features only (mean of patch embeddings), following
    LLaVA's configuration: mm_vision_select_feature="patch".

    Provides single-image and folder-wide embedding, with options to
    print only or append to a persistent pickle store via atomic writes.
    """

    def __init__(
        self,
        vision_model_name: str = DEFAULT_VISION_MODEL,
        projector_weights_path: Union[str, Path] = PROJECTOR_WEIGHTS_PATH,
        output_dim: int = OUTPUT_DIM,
    ):
        # 1) Set up CLIP processor + frozen vision model
        self.processor = CLIPImageProcessor.from_pretrained(vision_model_name)
        self.vision_model = (
            CLIPVisionModel
            .from_pretrained(vision_model_name, output_hidden_states=True)
            .to(DEVICE)
            .eval()
        )
        for p in self.vision_model.parameters():
            p.requires_grad = False

        # 2) Build the MLP projector - PATCH ONLY (1024-D input, not 2048-D)
        # Follows the LLaVA config: mm_vision_select_feature="patch", mm_hidden_size=1024
        hidden_dim = self.vision_model.config.hidden_size  # 1024, not 1024*2
        self.projector = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),   # 1024 -> 4096
            nn.GELU(),
            nn.Linear(output_dim, output_dim),   # 4096 -> 4096
        ).to(DEVICE)

        # 3) Load the pretrained projector weights
        ckpt = torch.load(projector_weights_path, map_location=DEVICE)
        state = ckpt.get("projector_state_dict", ckpt)
        # Strip the "model.mm_projector." prefix used in the LLaVA checkpoint
        clean_state = {
            (k[len("model.mm_projector."):] if k.startswith("model.mm_projector.") else k): v
            for k, v in state.items()
        }
        # Load, warning on any mismatch
        missing, unexpected = self.projector.load_state_dict(clean_state, strict=False)
        if missing or unexpected:
            logger.warning(f"Loaded projector weights with missing={missing}, unexpected={unexpected}")

    def image_to_embedding(self, image: Image.Image) -> torch.Tensor:
        """
        Embed a single PIL.Image → (output_dim,) L2-normalized vector.
        Uses PATCH-LEVEL features only (mean of patch tokens, no CLS token).
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        except Exception as e:
            logger.exception(f"Preprocessing failed: {e}")
            raise

        with torch.no_grad():
            last_hidden = self.vision_model(**inputs).last_hidden_state

        # PATCH-ONLY: use the mean of the patch embeddings, drop the CLS token
        # last_hidden shape: (1, 1 + num_patches, 1024)
        #   [:, 0, :]  = CLS token (skipped)
        #   [:, 1:, :] = patch tokens (averaged)
        patch_mean = last_hidden[:, 1:, :].mean(dim=1)            # (1, 1024)
        combo = nn.functional.normalize(patch_mean, p=2, dim=1)   # L2-normalize

        proj = self.projector(combo)                              # (1, 4096)
        emb = nn.functional.normalize(proj, p=2, dim=1)           # L2-normalize output
        return emb.squeeze(0)                                     # → (OUTPUT_DIM,)

    def embed_folder(
        self,
        folder_path: Union[str, Path],
        save_path: Union[str, Path] = IMAGE_EMB_PATH,
        extensions: List[str] = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ) -> Dict[str, np.ndarray]:
        """
        Embed all images in a folder (recursively), skipping already-embedded
        and corrupted files, and saving after each batch.
        Uses PATCH-LEVEL features only.
        """
        if extensions is None:
            extensions = [".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".gif"]

        folder = Path(folder_path)
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Load existing embeddings
        try:
            with open(save_path, "rb") as fp:
                mapping = pickle.load(fp)
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Could not load embeddings: {e}")
            mapping = {}

        # Collect files that still need embedding
        all_files = [f for f in folder.rglob("*") if f.suffix.lower() in extensions]
        to_process = [f for f in all_files if str(f) not in mapping]
        if not to_process:
            logger.info("No new images to embed.")
            return mapping

        logger.info(f"Embedding {len(to_process)} images (batch size {batch_size})")
        for i in tqdm(range(0, len(to_process), batch_size), desc="Images"):
            batch = to_process[i: i + batch_size]
            imgs, paths = [], []
            for f in batch:
                try:
                    imgs.append(Image.open(f).convert("RGB"))
                    paths.append(f)
                except UnidentifiedImageError:
                    logger.warning(f"Skipping corrupted: {f}")
                except Exception as e:
                    logger.warning(f"Error loading {f}: {e}")

            if not imgs:
                continue

            # Embed the batch
            inputs = self.processor(images=imgs, return_tensors="pt")
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            with torch.no_grad():
                last_hidden = self.vision_model(**inputs).last_hidden_state

            # PATCH-ONLY: mean of patch tokens, CLS token dropped
            patch_means = last_hidden[:, 1:, :].mean(dim=1)        # (batch_size, 1024)
            combo = nn.functional.normalize(patch_means, p=2, dim=1)
            proj = self.projector(combo)                           # (batch_size, 4096)
            embs = nn.functional.normalize(proj, p=2, dim=1).cpu().detach().numpy()

            # Update the mapping and save atomically after each batch
            for pth, emb in zip(paths, embs):
                mapping[str(pth)] = emb
            with atomic_write(save_path, mode="wb", overwrite=True) as fp:
                pickle.dump(mapping, fp, protocol=pickle.HIGHEST_PROTOCOL)

        logger.info(f"Done. Total embeddings: {len(mapping)}")
        return mapping

    def embed_image(
        self,
        image_path: Union[str, Path],
        save_path: Union[str, Path] = IMAGE_EMB_PATH,
        no_save: bool = False,
    ) -> torch.Tensor:
        """
        Embed one image using PATCH-LEVEL features only:
          - prints the vector
          - unless no_save=True, also appends it to the pickle at save_path
        Returns the torch.Tensor.
        """
        img_path = Path(image_path)
        try:
            img = Image.open(img_path).convert("RGB")
            vec = self.image_to_embedding(img)
        except Exception as e:
            logger.error(f"Failed embedding {img_path}: {e}")
            raise

        # Print the vector
        vec_np = vec.cpu().detach().numpy()
        print(f"{img_path} → shape={vec_np.shape}")
        print(vec_np)

        # Append to the pickle unless disabled
        if not no_save:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            try:
                with open(save_path, "rb") as fp:
                    mapping = pickle.load(fp)
            except (FileNotFoundError, pickle.PickleError):
                mapping = {}

            mapping[str(img_path)] = vec_np
            # Write in binary mode so the pickled bytes can be written atomically
            with atomic_write(save_path, mode="wb", overwrite=True) as fp:
                fp.write(pickle.dumps(mapping, protocol=pickle.HIGHEST_PROTOCOL))
            logger.info(f"Saved embedding for {img_path} to {save_path}")

        return vec


def main():
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--image", type=str, help="Path to a single image to embed")
    parser.add_argument("--folder", type=str, help="Path to a folder of images")
    parser.add_argument("--out", type=str, default=str(IMAGE_EMB_PATH),
                        help="Pickle to save embeddings (default: embeddings/image_embeddings.pkl)")
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
                        help="Batch size for folder embedding")
    parser.add_argument("--no-save", action="store_true",
                        help="When embedding a single image, do not save to the pickle")
    args = parser.parse_args()

    embedder = ImageEmbedder()

    if args.image:
        embedder.embed_image(
            image_path=args.image,
            save_path=args.out,
            no_save=args.no_save,
        )
    elif args.folder:
        embedder.embed_folder(
            folder_path=args.folder,
            save_path=args.out,
            batch_size=args.batch_size,
        )
    else:
        parser.error("Please specify --image or --folder")


if __name__ == "__main__":
    main()
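A minimal programmatic usage sketch for the class above, assuming the package root is on `PYTHONPATH` and `models/llava-v1.5-mlp2x-336px-vicuna-7b/mm_projector.bin` has been downloaded; the image paths are placeholders:

```python
# Usage sketch only; the image/folder paths below are placeholders.
from embed.image_embedder import ImageEmbedder

embedder = ImageEmbedder()

# Embed one image without touching the pickle store
vec = embedder.embed_image("data/coco/val2017/example.jpg", no_save=True)
print(vec.shape)  # torch.Size([4096])

# Embed a whole folder, appending results to embeddings/image_embeddings.pkl
mapping = embedder.embed_folder("data/coco/val2017", batch_size=32)
```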
138
embed/text_embedder.py
Normal file
@@ -0,0 +1,138 @@
import pickle
import logging
from pathlib import Path
from typing import Dict, Union

import torch
import torch.nn as nn
import numpy as np
from sentence_transformers import SentenceTransformer
from cachetools import LRUCache
from tqdm import tqdm
from atomicwrites import atomic_write

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device & defaults
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use a model that is good for semantic similarity
DEFAULT_MODEL = "sentence-transformers/clip-ViT-B-32-multilingual-v1"  # CLIP-based, good for matching images
OUTPUT_DIM = 4096
ARTICLE_CACHE_SIZE = 1024

# Persistence paths
EMBEDDINGS_DIR = Path(__file__).parent.parent / "embeddings"
TEXT_EMB_PATH = EMBEDDINGS_DIR / "text_embeddings.pkl"


class SimpleSentenceEmbedder:
    """
    Simple text embedder using sentence-transformers with a CLIP backbone,
    which should give much better text-image alignment.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL,
        output_dim: int = OUTPUT_DIM,
    ):
        # Load the sentence-transformer model
        self.model = SentenceTransformer(model_name, device=DEVICE)

        # Get the model's native output dimension
        model_dim = self.model.get_sentence_embedding_dimension()

        # Simple projector to match the 4096-D image embeddings
        self.projector = nn.Sequential(
            nn.Linear(model_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ).to(DEVICE)

        # Xavier initialization for the (untrained) projector
        for layer in self.projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

        self.article_cache = LRUCache(maxsize=ARTICLE_CACHE_SIZE)

        # Ensure the embeddings directory exists
        EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

        logger.info(f"Initialized SentenceEmbedder with {model_name} ({model_dim}D -> {output_dim}D)")

    def text_to_embedding(self, text: str) -> torch.Tensor:
        """
        Embed text using the sentence transformer + simple projector.
        """
        if text in self.article_cache:
            return self.article_cache[text]

        with torch.no_grad():
            # Get the sentence embedding
            sentence_emb = self.model.encode([text], convert_to_tensor=True, device=DEVICE)
            sentence_emb = nn.functional.normalize(sentence_emb, p=2, dim=1)

            # Project to the target dimension
            projected = self.projector(sentence_emb)
            projected = nn.functional.normalize(projected, p=2, dim=1)

        result = projected.squeeze(0).cpu()

        self.article_cache[text] = result
        return result

    def embed_file(
        self,
        text_file: Union[str, Path],
        save_path: Union[str, Path] = TEXT_EMB_PATH
    ) -> Dict[str, np.ndarray]:
        """
        Read newline-separated articles from a file and embed them.
        """
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            with open(save_path, "rb") as fp:
                mapping: Dict[str, np.ndarray] = pickle.load(fp)
        except FileNotFoundError:
            mapping = {}
        except Exception as e:
            logger.warning(f"Failed loading {save_path}: {e}")
            mapping = {}

        try:
            with open(text_file, "r", encoding="utf-8") as fp:
                lines = [ln.strip() for ln in fp if ln.strip()]
        except Exception as e:
            raise RuntimeError(f"Error reading {text_file}: {e}")

        for ln in tqdm(lines, desc="Embedding articles"):
            if ln not in mapping:
                try:
                    vec = self.text_to_embedding(ln)
                    mapping[ln] = vec.cpu().numpy()
                except Exception as e:
                    logger.error(f"Embedding failed for article: {e}")

        with atomic_write(save_path, mode='wb', overwrite=True) as f:
            pickle.dump(mapping, f, protocol=pickle.HIGHEST_PROTOCOL)

        return mapping

    def get_memory_usage(self):
        """Get current GPU memory usage (in GB) if available."""
        if torch.cuda.is_available():
            return {
                'allocated': torch.cuda.memory_allocated() / 1024 ** 3,  # GB
                'reserved': torch.cuda.memory_reserved() / 1024 ** 3,    # GB
            }
        return {'allocated': 0, 'reserved': 0}


# Alias for compatibility
TextEmbedder = SimpleSentenceEmbedder
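Note that readme.md (later in this commit) documents a `python -m embed.text_embedder` CLI with `--text`, `--file`, and `--out` flags, while this module defines no entry point. A minimal sketch of what such an entry point could look like, assuming it is appended to embed/text_embedder.py and mirrors `image_embedder.main()`; the flag names come from readme.md, everything else is illustrative:

```python
# Hypothetical CLI for embed/text_embedder.py; flag names follow readme.md.
def main():
    import argparse

    parser = argparse.ArgumentParser(description="Embed articles with SimpleSentenceEmbedder")
    parser.add_argument("--text", type=str, help="A single article to embed (printed to stdout)")
    parser.add_argument("--file", type=str, help="Path to a newline-separated file of articles")
    parser.add_argument("--out", type=str, default=str(TEXT_EMB_PATH),
                        help="Pickle to save embeddings to")
    args = parser.parse_args()

    embedder = SimpleSentenceEmbedder()

    if args.text:
        vec = embedder.text_to_embedding(args.text)
        print(vec.cpu().numpy())
    elif args.file:
        embedder.embed_file(args.file, save_path=args.out)
    else:
        parser.error("Please specify --text or --file")


if __name__ == "__main__":
    main()
```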
BIN
embeddings/chunk_texts.pkl
Normal file
Binary file not shown.
BIN
embeddings/image_embeddings.pkl
Normal file
Binary file not shown.
103
examples/search.py
Normal file
@@ -0,0 +1,103 @@
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

IMG_EMB_PKL = Path("E:/GiteaRepo/text2img/embeddings/image_embeddings.pkl")
TOP_K = 5
query_text = """orange cat"""


def get_text_embedding(text, model, projector, device):
    """Get a text embedding using the sentence transformer + projector."""
    with torch.no_grad():
        # Get the sentence embedding
        sentence_emb = model.encode([text], convert_to_tensor=True, device=device)
        sentence_emb = nn.functional.normalize(sentence_emb, p=2, dim=1)

        # Project to 4096-D
        projected = projector(sentence_emb)
        projected = nn.functional.normalize(projected, p=2, dim=1)

        return projected.squeeze(0).cpu().numpy()


def main():
    print(f"Loading image embeddings from {IMG_EMB_PKL}")

    # 1) Load image embeddings
    try:
        with open(IMG_EMB_PKL, "rb") as f:
            data = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: Could not find {IMG_EMB_PKL}")
        return

    paths = np.array(list(data.keys()))
    vecs = np.stack(list(data.values()))  # shape (N, 4096)
    print(f"Loaded {len(paths)} image embeddings with shape {vecs.shape}")

    # 2) Initialize the text embedding model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Use a CLIP-based sentence transformer for better text-image alignment
    model_name = "sentence-transformers/clip-ViT-B-32-multilingual-v1"
    print(f"Loading text model: {model_name}")

    try:
        model = SentenceTransformer(model_name, device=device)
        model_dim = model.get_sentence_embedding_dimension()
        print(f"Text model dimension: {model_dim}")

        # Simple projector to match the 4096-D image embeddings
        projector = nn.Sequential(
            nn.Linear(model_dim, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
        ).to(device)

        # Xavier initialization for the (untrained) projector
        for layer in projector:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    except Exception as e:
        print(f"Error loading text model: {e}")
        return

    # 3) Embed the query
    print(f"Embedding query: '{query_text}'")
    try:
        q = get_text_embedding(query_text, model, projector, device)
        print(f"Query embedding shape: {q.shape}")
        print(f"Query embedding norm: {np.linalg.norm(q):.4f}")
    except Exception as e:
        print(f"Error embedding query: {e}")
        return

    # 4) Cosine similarity (dot product, since both sides are L2-normalized)
    print("Computing similarities...")
    scores = vecs @ q
    top_k_idx = np.argsort(-scores)[:TOP_K]

    # 5) Display results
    print(f"\nTop-{TOP_K} images for: {query_text!r}\n")
    print("-" * 80)
    for rank, idx in enumerate(top_k_idx, 1):
        print(f"{rank:>2}. {paths[idx]:<60} score={scores[idx]:>7.4f}")
    print("-" * 80)

    # Show stats
    print("\nStatistics:")
    print(f"Mean similarity: {scores.mean():.4f}")
    print(f"Max similarity:  {scores.max():.4f}")
    print(f"Min similarity:  {scores.min():.4f}")
    print(f"Std similarity:  {scores.std():.4f}")


if __name__ == "__main__":
    main()
261
examples/similarity_finder.py
Normal file
@@ -0,0 +1,261 @@
import pickle
import logging
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, Optional
import argparse

import numpy as np
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default paths
EMBED_DIR = Path(__file__).parent.parent / "embeddings"
IMAGE_EMB_PATH = EMBED_DIR / "image_embeddings.pkl"


class ImageSimilarityFinder:
    """
    Find similar images using precomputed embeddings from image_embeddings.pkl.
    Uses cosine similarity to rank images.
    """

    def __init__(self, embeddings_path: Union[str, Path] = IMAGE_EMB_PATH):
        """
        Initialize the similarity finder with precomputed embeddings.

        Args:
            embeddings_path: Path to the pickle file containing image embeddings
        """
        self.embeddings_path = Path(embeddings_path)
        self.embeddings = self._load_embeddings()
        self.image_paths = list(self.embeddings.keys())
        self.embedding_matrix = np.array(list(self.embeddings.values()))

        logger.info(f"Loaded {len(self.embeddings)} image embeddings")

    def _load_embeddings(self) -> Dict[str, np.ndarray]:
        """Load embeddings from the pickle file."""
        try:
            with open(self.embeddings_path, "rb") as fp:
                embeddings = pickle.load(fp)
            return embeddings
        except FileNotFoundError:
            raise FileNotFoundError(f"Embeddings file not found: {self.embeddings_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to load embeddings: {e}")

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """
        Calculate the cosine similarity between two vectors.
        Since the embeddings are already L2-normalized, this is just a dot product.
        """
        return np.dot(vec1, vec2)

    def find_similar_by_path(
        self,
        query_image_path: Union[str, Path],
        k: int = 5,
        exclude_self: bool = True,
    ) -> List[Tuple[str, float]]:
        """
        Find the k most similar images to a query image given its path.

        Args:
            query_image_path: Path to the query image
            k: Number of similar images to return
            exclude_self: Whether to exclude the query image itself from the results

        Returns:
            List of (image_path, similarity_score) tuples sorted by similarity (highest first)
        """
        query_path = str(Path(query_image_path))

        if query_path not in self.embeddings:
            raise ValueError(f"Query image not found in embeddings: {query_path}")

        query_embedding = self.embeddings[query_path]
        return self._find_similar_by_embedding(
            query_embedding, k, exclude_path=query_path if exclude_self else None
        )

    def find_similar_by_embedding(
        self,
        query_embedding: np.ndarray,
        k: int = 5,
    ) -> List[Tuple[str, float]]:
        """
        Find the k most similar images to a query embedding.

        Args:
            query_embedding: Query image embedding vector
            k: Number of similar images to return

        Returns:
            List of (image_path, similarity_score) tuples sorted by similarity (highest first)
        """
        return self._find_similar_by_embedding(query_embedding, k)

    def _find_similar_by_embedding(
        self,
        query_embedding: np.ndarray,
        k: int,
        exclude_path: Optional[str] = None,
    ) -> List[Tuple[str, float]]:
        """
        Internal method to find similar images by embedding.

        Args:
            query_embedding: Query embedding vector
            k: Number of results to return
            exclude_path: Path to exclude from the results (typically the query image itself)

        Returns:
            List of (image_path, similarity_score) tuples sorted by similarity
        """
        # Ensure the query embedding is normalized (it should already be, from ImageEmbedder)
        query_embedding = query_embedding / np.linalg.norm(query_embedding)

        # Compute similarities against all embeddings
        similarities = np.dot(self.embedding_matrix, query_embedding)

        # Get indices sorted by similarity (descending)
        sorted_indices = np.argsort(similarities)[::-1]

        # Collect results, excluding the query image if requested
        results = []
        for idx in sorted_indices:
            image_path = self.image_paths[idx]
            similarity = similarities[idx]

            # Skip the excluded path
            if exclude_path and image_path == exclude_path:
                continue

            results.append((image_path, float(similarity)))

            # Stop once we have k results
            if len(results) >= k:
                break

        return results

    def find_similar_images_batch(
        self,
        query_paths: List[Union[str, Path]],
        k: int = 5,
    ) -> Dict[str, List[Tuple[str, float]]]:
        """
        Find similar images for multiple query images at once.

        Args:
            query_paths: List of paths to query images
            k: Number of similar images to return for each query

        Returns:
            Dictionary mapping query_path -> list of (similar_path, similarity_score)
        """
        results = {}
        for query_path in query_paths:
            try:
                results[str(query_path)] = self.find_similar_by_path(query_path, k)
            except ValueError as e:
                logger.warning(f"Skipping {query_path}: {e}")
                results[str(query_path)] = []

        return results

    def get_image_info(self) -> Dict[str, Any]:
        """Get information about the loaded embeddings."""
        return {
            "total_images": len(self.embeddings),
            "embedding_dimension": self.embedding_matrix.shape[1] if len(self.embedding_matrix) > 0 else 0,
            "sample_paths": self.image_paths[:5] if self.image_paths else [],
        }


def main():
    parser = argparse.ArgumentParser(
        description="Find similar images using precomputed embeddings"
    )
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Path to the query image",
    )
    parser.add_argument(
        "--embeddings",
        type=str,
        default=str(IMAGE_EMB_PATH),
        help="Path to the embeddings pickle file",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=5,
        help="Number of similar images to return (default: 5)",
    )
    parser.add_argument(
        "--include-self",
        action="store_true",
        help="Include the query image in the results",
    )
    parser.add_argument(
        "--info",
        action="store_true",
        help="Show information about the loaded embeddings",
    )

    args = parser.parse_args()

    # Initialize the similarity finder
    try:
        finder = ImageSimilarityFinder(args.embeddings)
    except Exception as e:
        logger.error(f"Failed to initialize similarity finder: {e}")
        logger.error("Make sure you're providing the path to the embeddings pickle file, not an image file")
        logger.error("Expected: --embeddings /path/to/image_embeddings.pkl")
        return

    # Show info if requested
    if args.info:
        info = finder.get_image_info()
        print("Embeddings Info:")
        print(f"  Total images: {info['total_images']}")
        print(f"  Embedding dimension: {info['embedding_dimension']}")
        print(f"  Sample paths: {info['sample_paths']}")
        print()

    # Find similar images
    try:
        similar_images = finder.find_similar_by_path(
            args.query,
            k=args.k,
            exclude_self=not args.include_self,
        )

        print(f"Top {len(similar_images)} similar images to '{args.query}':")
        print("-" * 80)

        for i, (image_path, similarity) in enumerate(similar_images, 1):
            print(f"{i:2d}. {Path(image_path).name}")
            print(f"    Path: {image_path}")
            print(f"    Similarity: {similarity:.4f}")
            print()

    except Exception as e:
        logger.error(f"Failed to find similar images: {e}")

        # If the query image is missing from the embeddings, list a few available paths
        if "not found in embeddings" in str(e):
            logger.error("Available images in embeddings:")
            info = finder.get_image_info()
            for i, path in enumerate(info['sample_paths']):
                logger.error(f"  {i + 1}. {path}")
            if len(finder.image_paths) > 5:
                logger.error(f"  ... and {len(finder.image_paths) - 5} more images")


if __name__ == "__main__":
    main()
12
examples/starter.py
Normal file
@@ -0,0 +1,12 @@
from similarity_finder import ImageSimilarityFinder

# Initialize with the embeddings pickle file (not an image file)
finder = ImageSimilarityFinder("E:/GiteaRepo/text2img/embeddings/image_embeddings.pkl")

# Query with the actual image path
query_image = "data/coco/val2017/1140002154.jpg"
similar_images = finder.find_similar_by_path(query_image, k=5)

print(f"Top 5 similar images to {query_image}:")
for i, (image_path, similarity) in enumerate(similar_images, 1):
    print(f"{i}. {image_path} (similarity: {similarity:.4f})")
41
models/llava-v1.5-mlp2x-336px-vicuna-7b/README.md
Normal file
@@ -0,0 +1,41 @@
---
inference: false
---

<br>
<br>

# LLaVA Model Card

This is a pretrained checkpoint; you can use it to instruction-tune your multimodal models.

Check out the instructions [here](https://github.com/haotian-liu/LLaVA/blob/main/README.md#visual-instruction-tuning)

## Model details

**Model type:**
LLaVA is an open-source chatbot trained by fine-tuning LLaMA/Vicuna on GPT-generated multimodal instruction-following data.
It is an auto-regressive language model, based on the transformer architecture.

**Model date:**
LLaVA-v1.5-MLP2x-336px-Pretrain-Vicuna-7B-v1.5 was trained in September 2023.

**Paper or resources for more information:**
https://llava-vl.github.io/

## License
Llama 2 is licensed under the LLAMA 2 Community License,
Copyright (c) Meta Platforms, Inc. All Rights Reserved.

**Where to send questions or comments about the model:**
https://github.com/haotian-liu/LLaVA/issues

## Intended use
**Primary intended uses:**
The primary use of LLaVA is research on large multimodal models and chatbots.

**Primary intended users:**
The primary intended users of the model are researchers and hobbyists in computer vision, natural language processing, machine learning, and artificial intelligence.

## Training dataset
- 558K filtered image-text pairs from LAION/CC/SBU, captioned by BLIP.
38
models/llava-v1.5-mlp2x-336px-vicuna-7b/config.json
Normal file
@@ -0,0 +1,38 @@
{
  "_name_or_path": "./checkpoints/vicuna-7b-v1-5",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "image_aspect_ratio": "square",
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "mm_hidden_size": 1024,
  "mm_patch_merge_type": "flat",
  "mm_projector_type": "mlp2x_gelu",
  "mm_use_im_patch_token": false,
  "mm_use_im_start_end": false,
  "mm_vision_select_feature": "patch",
  "mm_vision_select_layer": -2,
  "mm_vision_tower": "openai/clip-vit-large-patch14-336",
  "model_type": "llava",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.31.0",
  "tune_mm_mlp_adapter": true,
  "tune_mm_vision_resampler": false,
  "use_cache": true,
  "use_mm_proj": true,
  "vocab_size": 32000
}
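These are the keys embed/image_embedder.py leans on: `mm_vision_tower` names the CLIP encoder, `mm_vision_select_feature: "patch"` is why the CLS token is dropped, and `mm_projector_type: "mlp2x_gelu"` with `mm_hidden_size: 1024` and `hidden_size: 4096` fixes the projector shape (note that `ImageEmbedder` reads the last hidden state rather than the layer indicated by `mm_vision_select_layer: -2`). A rough sketch of that correspondence, for illustration only:

```python
# Illustrative only: the projector shape implied by this config,
# matching the one built in ImageEmbedder.__init__.
import torch.nn as nn

mm_hidden_size = 1024   # "mm_hidden_size": CLIP ViT-L/14-336 feature width
hidden_size = 4096      # "hidden_size": LLM width, i.e. the projector output dim

# "mm_projector_type": "mlp2x_gelu" -> two Linear layers with a GELU between them
projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)
```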
35
models/llava-v1.5-mlp2x-336px-vicuna-7b/gitattributes
Normal file
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
BIN
models/llava-v1.5-mlp2x-336px-vicuna-7b/mm_projector.bin
Normal file
Binary file not shown.
13111
models/llava-v1.5-mlp2x-336px-vicuna-7b/trainer_state.json
Normal file
File diff suppressed because it is too large
53
readme.md
Normal file
@@ -0,0 +1,53 @@
# Text2Img

A rough implementation of generating image embeddings using the methodology introduced in LLaVA.

### Structure
We derive the image embeddings by running images through a CLIP vision encoder and mapping the resulting features through LLaVA's pretrained projection (mm_projector) layer. A condensed sketch of the pipeline follows below.
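In code, the pipeline is roughly the following (a condensed, illustrative sketch of what `embed/image_embedder.py` does; `"example.jpg"` is a placeholder and the real script also loads the pretrained `mm_projector.bin` weights):

```python
# Condensed sketch of the embedding pipeline in embed/image_embedder.py.
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336").eval()

# 2-layer MLP projector (mlp2x_gelu); the real code loads LLaVA's mm_projector.bin here
projector = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 4096))

image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    patches = encoder(**inputs).last_hidden_state[:, 1:, :]         # drop CLS, keep patch tokens
    feat = nn.functional.normalize(patches.mean(dim=1), p=2, dim=1)  # mean-pool + L2-normalize
    emb = nn.functional.normalize(projector(feat), p=2, dim=1)       # (1, 4096)
```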
### Prerequisites
1. Install the packages in requirements.txt
2. Make sure you have downloaded [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin)
3. For example image data, I use [2017 Val images 5K/1GB](http://images.cocodataset.org/zips/val2017.zip) and [2017 Train/Val annotations 241MB](http://images.cocodataset.org/annotations/annotations_trainval2017.zip)

### Usage

For image_embedder.py:

1. Embed a single image (print only):
   `python -m embed.image_embedder --image "C:\path\img.jpg" --no-save`

2. Embed a single image (save to file):
   `python -m embed.image_embedder --image "C:\path\to\image.jpg" --out "C:\project\embeddings\image_embeddings.pkl"`

3. Embed a folder of images:
   `python -m embed.image_embedder --folder "C:\path\to\images" --out "C:\project\embeddings\image_embeddings.pkl" --batch-size 32`

For text_embedder.py:

1. Embed a single article (print only):
   `python -m embed.text_embedder --text "This is my single-article input string."`

2. Embed a single article (save to file):
   `python -m embed.text_embedder --text "This is my single-article input string." --out "C:\project\embeddings\text_embeddings.pkl"`

3. Embed multiple articles from a file (one per line):
   `python -m embed.text_embedder --file "C:\path\to\articles.txt" --out "C:\project\embeddings\text_embeddings.pkl" --batch-size 8`