# tattoo_search_engine/embeddings.py
from abc import ABC, abstractmethod
from typing import Dict, Any, List
import torch
import torch.nn.functional as F
from PIL import Image
import logging
logger = logging.getLogger(__name__)
class EmbeddingModel(ABC):
"""Abstract base class for embedding models."""
def __init__(self, device: torch.device):
self.device = device
self.model = None
self.preprocess = None
@abstractmethod
def load_model(self) -> None:
"""Load the embedding model and preprocessing."""
pass
@abstractmethod
def encode_image(self, image: Image.Image) -> torch.Tensor:
"""Encode an image into feature vector."""
pass
def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
"""Encode an image into patch-level features. Override in subclasses that support it."""
raise NotImplementedError("Patch-level encoding not implemented for this model")
def compute_patch_attention(self, query_patches: torch.Tensor, candidate_patches: torch.Tensor) -> torch.Tensor:
"""Compute attention weights between query and candidate patches."""
# query_patches: [num_query_patches, feature_dim]
# candidate_patches: [num_candidate_patches, feature_dim]
# Normalize patches
query_patches = F.normalize(query_patches, p=2, dim=1)
candidate_patches = F.normalize(candidate_patches, p=2, dim=1)
# Compute attention matrix: [num_query_patches, num_candidate_patches]
attention_matrix = torch.mm(query_patches, candidate_patches.T)
return attention_matrix
@abstractmethod
def get_model_name(self) -> str:
"""Return the model name."""
pass
def compute_similarity(self, query_features: torch.Tensor, candidate_features: torch.Tensor) -> float:
"""Compute similarity between query and candidate features."""
return torch.mm(query_features, candidate_features.T).item()
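# Illustrative helper, not part of the original interface: one possible way to
# collapse the patch-level similarity matrix returned by
# EmbeddingModel.compute_patch_attention into a single image-level score
# (average of each query patch's best candidate match). The name and the
# aggregation rule are assumptions chosen for demonstration.
def aggregate_patch_similarity(attention_matrix: torch.Tensor) -> float:
    """Average, over query patches, of the best-matching candidate patch similarity."""
    # attention_matrix: [num_query_patches, num_candidate_patches] of cosine similarities
    best_match_per_query, _ = attention_matrix.max(dim=1)
    return best_match_per_query.mean().item()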
class CLIPEmbedding(EmbeddingModel):
"""CLIP-based embedding model."""
def __init__(self, device: torch.device, model_name: str = "ViT-B-32"):
super().__init__(device)
self.model_name = model_name
self.tokenizer = None
self.load_model()
def load_model(self) -> None:
"""Load CLIP model and preprocessing."""
try:
import open_clip
logger.info(f"Loading CLIP model: {self.model_name}")
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
self.model_name, pretrained="openai"
)
            self.model.to(self.device)
            self.model.eval()
self.tokenizer = open_clip.get_tokenizer(self.model_name)
logger.info(f"CLIP model {self.model_name} loaded successfully")
except Exception as e:
logger.error(f"Failed to load CLIP model: {e}")
raise
def encode_image(self, image: Image.Image) -> torch.Tensor:
"""Encode image using CLIP."""
try:
image_input = self.preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
features = self.model.encode_image(image_input)
features = F.normalize(features, p=2, dim=1)
return features
except Exception as e:
logger.error(f"Failed to encode image with CLIP: {e}")
raise
def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
"""Encode image patches using CLIP vision transformer."""
try:
image_input = self.preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
# Get patch features from CLIP vision transformer
vision_model = self.model.visual
# Pass through patch embedding and positional encoding
x = vision_model.conv1(image_input) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
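                # For ViT-B-32 at its default 224x224 input this is a 7x7 grid,
                # i.e. 49 patch tokens per image.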
# Add class token and positional embeddings
x = torch.cat([vision_model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
x = x + vision_model.positional_embedding.to(x.dtype)
# Apply layer norm
x = vision_model.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
# Pass through transformer blocks
for block in vision_model.transformer.resblocks:
x = block(x)
x = x.permute(1, 0, 2) # LND -> NLD
# Remove class token to get only patch features
patch_features = x[:, 1:, :] # [1, num_patches, feature_dim]
patch_features = vision_model.ln_post(patch_features)
# Apply projection if it exists
if vision_model.proj is not None:
patch_features = patch_features @ vision_model.proj
# Normalize patch features
patch_features = F.normalize(patch_features, p=2, dim=-1)
return patch_features.squeeze(0) # [num_patches, feature_dim]
except Exception as e:
logger.error(f"Failed to encode image patches with CLIP: {e}")
raise
def get_model_name(self) -> str:
return f"CLIP-{self.model_name}"
class DINOv2Embedding(EmbeddingModel):
"""DINOv2-based embedding model."""
def __init__(self, device: torch.device, model_name: str = "dinov2_vitb14"):
super().__init__(device)
self.model_name = model_name
self.load_model()
def load_model(self) -> None:
"""Load DINOv2 model and preprocessing."""
try:
import torch.hub
from torchvision import transforms
logger.info(f"Loading DINOv2 model: {self.model_name}")
# Load DINOv2 model from torch hub
self.model = torch.hub.load('facebookresearch/dinov2', self.model_name)
self.model.to(self.device)
self.model.eval()
# DINOv2 preprocessing
            self.preprocess = transforms.Compose([
                # Convert to RGB first so grayscale/RGBA inputs don't break Normalize
                transforms.Lambda(lambda img: img.convert("RGB")),
                transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
logger.info(f"DINOv2 model {self.model_name} loaded successfully")
except Exception as e:
logger.error(f"Failed to load DINOv2 model: {e}")
raise
def encode_image(self, image: Image.Image) -> torch.Tensor:
"""Encode image using DINOv2."""
try:
image_input = self.preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
features = self.model(image_input)
features = F.normalize(features, p=2, dim=1)
return features
except Exception as e:
logger.error(f"Failed to encode image with DINOv2: {e}")
raise
def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
"""Encode image patches using DINOv2."""
try:
image_input = self.preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
# Get patch features from DINOv2
# DINOv2 forward_features returns dict with 'x_norm_patchtokens' containing patch features
features_dict = self.model.forward_features(image_input)
patch_features = features_dict['x_norm_patchtokens'] # [1, num_patches, feature_dim]
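                # For dinov2_vitb14 with the 224x224 crop above this is a 16x16
                # grid, i.e. 256 patch tokens per image.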
# Normalize patch features
patch_features = F.normalize(patch_features, p=2, dim=-1)
return patch_features.squeeze(0) # [num_patches, feature_dim]
except Exception as e:
logger.error(f"Failed to encode image patches with DINOv2: {e}")
raise
def get_model_name(self) -> str:
return f"DINOv2-{self.model_name}"
class DINOv2WithRegistersEmbedding(EmbeddingModel):
"""DINOv2 with register tokens - improved feature maps and attention."""
def __init__(self, device: torch.device, model_name: str = "facebook/dinov2-with-registers-base"):
super().__init__(device)
self.model_name = model_name
self.processor = None
self.load_model()
def load_model(self) -> None:
"""Load DINOv2 with registers model and preprocessing."""
try:
from transformers import Dinov2WithRegistersModel, AutoImageProcessor
logger.info(f"Loading DINOv2 with registers model: {self.model_name}")
self.model = Dinov2WithRegistersModel.from_pretrained(self.model_name)
self.model.to(self.device)
self.model.eval()
self.processor = AutoImageProcessor.from_pretrained(self.model_name)
logger.info(f"DINOv2 with registers model {self.model_name} loaded successfully")
except Exception as e:
logger.error(f"Failed to load DINOv2 with registers model: {e}")
raise
def encode_image(self, image: Image.Image) -> torch.Tensor:
"""Encode image using DINOv2 with registers."""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
# Use pooler_output for global representation, fallback to mean pooling
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
features = outputs.pooler_output
                else:
                    # Fallback: mean-pool the patch tokens, skipping the CLS token
                    # (position 0) and the 4 register tokens that follow it.
                    features = outputs.last_hidden_state[:, 5:, :].mean(dim=1)
features = F.normalize(features, p=2, dim=1)
return features
except Exception as e:
logger.error(f"Failed to encode image with DINOv2 with registers: {e}")
raise
def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
"""Encode image patches using DINOv2 with registers."""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
# Token sequence structure: [CLS] + 4 register tokens + 256 patch tokens = 261 total
# We want only the spatial patch tokens (positions 5 to 260)
num_register_tokens = 4
expected_patches = (224 // 14) ** 2 # 256 for base model with 224x224 input, patch size 14
# Skip CLS token (position 0) and register tokens (positions 1-4)
start_idx = 1 + num_register_tokens # Position 5
end_idx = start_idx + expected_patches # Position 261
patch_features = outputs.last_hidden_state[:, start_idx:end_idx, :] # [1, 256, feature_dim]
# Normalize patch features
patch_features = F.normalize(patch_features, p=2, dim=-1)
return patch_features.squeeze(0) # [num_patches, feature_dim]
except Exception as e:
logger.error(f"Failed to encode image patches with DINOv2 with registers: {e}")
raise
def get_model_name(self) -> str:
return f"DINOv2-WithRegisters-{self.model_name.split('/')[-1]}"
def get_attention_maps(self, image: Image.Image) -> torch.Tensor:
"""
Extract native attention maps from DINOv2 with registers.
Returns:
Attention tensor with shape (num_layers, num_heads, num_tokens, num_tokens)
where num_tokens includes [CLS] + patches + registers
"""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
# outputs.attentions is a tuple of attention tensors, one per layer
# Each has shape: (batch_size, num_heads, sequence_length, sequence_length)
# Stack all layer attentions
attention_stack = torch.stack(outputs.attentions) # (num_layers, batch_size, num_heads, seq_len, seq_len)
attention_stack = attention_stack.squeeze(1) # Remove batch dimension -> (num_layers, num_heads, seq_len, seq_len)
return attention_stack
except Exception as e:
logger.error(f"Failed to extract attention maps: {e}")
raise
def compute_cross_attention(self, query_image: Image.Image, candidate_image: Image.Image) -> torch.Tensor:
"""
Compute cross-attention between query and candidate images using patch features.
This uses the extracted patch embeddings to compute attention from query to candidate,
similar to the native attention mechanism but across two images.
Returns:
Cross-attention matrix with shape (query_patches, candidate_patches)
"""
try:
# Get patch features for both images
query_patches = self.encode_image_patches(query_image) # (num_query_patches, feature_dim)
candidate_patches = self.encode_image_patches(candidate_image) # (num_candidate_patches, feature_dim)
# Compute attention-style similarity (softmax over candidate dimension)
# attention[i,j] = how much query patch i attends to candidate patch j
attention_logits = torch.mm(query_patches, candidate_patches.T) # (query_patches, candidate_patches)
# Apply softmax to get attention distribution for each query patch
cross_attention = F.softmax(attention_logits, dim=1)
return cross_attention
except Exception as e:
logger.error(f"Failed to compute cross-attention: {e}")
raise
def supports_native_attention(self) -> bool:
"""Check if this model supports native attention extraction."""
return True
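# Illustrative sketch (an assumed downstream use, not original code): reducing
# the raw stack returned by get_attention_maps to a CLS-to-patch heatmap. For
# the base model at 224x224 the token layout is [CLS] + 4 registers + 256
# patches, so the patch axis reshapes to a 16x16 grid; the same token ordering
# applies to the DINOv3 wrapper below, though its grid size depends on its
# patch size and input resolution.
def cls_attention_heatmap(attention_stack: torch.Tensor,
                          num_register_tokens: int = 4,
                          grid_size: int = 16) -> torch.Tensor:
    """Average the last layer's heads and return the CLS token's attention over the patch grid."""
    last_layer = attention_stack[-1]                 # (num_heads, seq_len, seq_len)
    cls_to_tokens = last_layer.mean(dim=0)[0]        # attention from CLS to every token
    cls_to_patches = cls_to_tokens[1 + num_register_tokens:]  # drop CLS + register tokens
    return cls_to_patches.reshape(grid_size, grid_size)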
class DINOv3Embedding(EmbeddingModel):
"""DINOv3-based embedding model from HuggingFace transformers."""
def __init__(self, device: torch.device, model_name: str = "facebook/dinov3-vits16-pretrain-lvd1689m"):
super().__init__(device)
self.model_name = model_name
self.processor = None
self.load_model()
def load_model(self) -> None:
"""Load DINOv3 model and preprocessing."""
try:
from transformers import AutoModel, AutoImageProcessor
logger.info(f"Loading DINOv3 model: {self.model_name}")
self.model = AutoModel.from_pretrained(self.model_name)
self.model.to(self.device)
self.model.eval()
self.processor = AutoImageProcessor.from_pretrained(self.model_name)
logger.info(f"DINOv3 model {self.model_name} loaded successfully")
except Exception as e:
logger.error(f"Failed to load DINOv3 model: {e}")
raise
def encode_image(self, image: Image.Image) -> torch.Tensor:
"""Encode image using DINOv3."""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
# Use pooler_output (CLS token) for global representation
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
features = outputs.pooler_output
                else:
                    # Fallback: mean-pool the patch tokens, skipping the CLS token
                    # (position 0) and the 4 register tokens that follow it.
                    features = outputs.last_hidden_state[:, 5:, :].mean(dim=1)
features = F.normalize(features, p=2, dim=1)
return features
except Exception as e:
logger.error(f"Failed to encode image with DINOv3: {e}")
raise
def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
"""Encode image patches using DINOv3."""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
# DINOv3 outputs: [CLS] + register tokens + patch tokens
# We want only the patch tokens (skip CLS at position 0 and register tokens)
# For DINOv3-ViTS16, it has 4 register tokens
num_register_tokens = 4
patch_features = outputs.last_hidden_state[:, 1 + num_register_tokens:, :]
# Normalize patch features
patch_features = F.normalize(patch_features, p=2, dim=-1)
return patch_features.squeeze(0) # [num_patches, feature_dim]
except Exception as e:
logger.error(f"Failed to encode image patches with DINOv3: {e}")
raise
def get_model_name(self) -> str:
return f"DINOv3-{self.model_name.split('/')[-1]}"
def supports_native_attention(self) -> bool:
"""Check if this model supports native attention extraction."""
return True
def get_attention_maps(self, image: Image.Image) -> torch.Tensor:
"""
Extract native attention maps from DINOv3.
Returns:
Attention tensor with shape (num_layers, num_heads, num_tokens, num_tokens)
"""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs, output_attentions=True)
# Stack all layer attentions
attention_stack = torch.stack(outputs.attentions)
attention_stack = attention_stack.squeeze(1) # Remove batch dimension
return attention_stack
except Exception as e:
logger.error(f"Failed to extract attention maps: {e}")
raise
def compute_cross_attention(self, query_image: Image.Image, candidate_image: Image.Image) -> torch.Tensor:
"""
Compute cross-attention between query and candidate images using patch features.
Returns:
Cross-attention matrix with shape (query_patches, candidate_patches)
"""
try:
query_patches = self.encode_image_patches(query_image)
candidate_patches = self.encode_image_patches(candidate_image)
# Compute attention-style similarity
attention_logits = torch.mm(query_patches, candidate_patches.T)
# Apply softmax to get attention distribution
cross_attention = F.softmax(attention_logits, dim=1)
return cross_attention
except Exception as e:
logger.error(f"Failed to compute cross-attention: {e}")
raise
class SigLIPEmbedding(EmbeddingModel):
"""SigLIP-based embedding model."""
def __init__(self, device: torch.device, model_name: str = "google/siglip-base-patch16-224"):
super().__init__(device)
self.model_name = model_name
self.processor = None
self.load_model()
def load_model(self) -> None:
"""Load SigLIP model and preprocessing."""
try:
# Check for required dependencies
try:
import sentencepiece
except ImportError:
raise ImportError(
"SentencePiece is required for SigLIP. Install with: pip install sentencepiece"
)
from transformers import SiglipVisionModel, SiglipProcessor
logger.info(f"Loading SigLIP model: {self.model_name}")
self.model = SiglipVisionModel.from_pretrained(self.model_name)
self.model.to(self.device)
self.model.eval()
self.processor = SiglipProcessor.from_pretrained(self.model_name)
logger.info(f"SigLIP model {self.model_name} loaded successfully")
except Exception as e:
logger.error(f"Failed to load SigLIP model: {e}")
raise
def encode_image(self, image: Image.Image) -> torch.Tensor:
"""Encode image using SigLIP."""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
features = outputs.last_hidden_state.mean(dim=1) # Global average pooling
features = F.normalize(features, p=2, dim=1)
return features
except Exception as e:
logger.error(f"Failed to encode image with SigLIP: {e}")
raise
def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
"""Encode image patches using SigLIP."""
try:
inputs = self.processor(images=image, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
# last_hidden_state contains patch features: [1, num_patches, feature_dim]
patch_features = outputs.last_hidden_state
# Normalize patch features
patch_features = F.normalize(patch_features, p=2, dim=-1)
return patch_features.squeeze(0) # [num_patches, feature_dim]
except Exception as e:
logger.error(f"Failed to encode image patches with SigLIP: {e}")
raise
def get_model_name(self) -> str:
return f"SigLIP-{self.model_name.split('/')[-1]}"
class EmbeddingModelFactory:
"""Factory class for creating embedding models."""
AVAILABLE_MODELS = {
"clip": CLIPEmbedding,
"dinov2": DINOv2Embedding,
"dinov2_registers": DINOv2WithRegistersEmbedding,
"dinov3": DINOv3Embedding,
"siglip": SigLIPEmbedding,
}
@classmethod
def create_model(cls, model_type: str, device: torch.device, **kwargs) -> EmbeddingModel:
"""Create an embedding model instance.
Args:
model_type: Type of model ('clip', 'dinov2', 'dinov2_registers', 'dinov3', 'siglip')
device: PyTorch device
**kwargs: Additional arguments for specific models
Returns:
EmbeddingModel instance
"""
if model_type.lower() not in cls.AVAILABLE_MODELS:
raise ValueError(f"Unknown model type: {model_type}. Available: {list(cls.AVAILABLE_MODELS.keys())}")
model_class = cls.AVAILABLE_MODELS[model_type.lower()]
try:
return model_class(device, **kwargs)
except Exception as e:
logger.error(f"Failed to create {model_type} model: {e}")
            # Fall back to CLIP if the requested model fails. Model-specific
            # kwargs (e.g. a HuggingFace checkpoint name) are dropped because
            # they would not apply to the CLIP constructor.
            if model_type.lower() != 'clip':
                logger.info("Falling back to CLIP model")
                return cls.AVAILABLE_MODELS['clip'](device)
else:
raise
@classmethod
def get_available_models(cls) -> List[str]:
"""Get list of available model types."""
return list(cls.AVAILABLE_MODELS.keys())
def get_default_model_configs() -> Dict[str, Dict[str, Any]]:
"""Get default configurations for each model type."""
return {
"clip": {
"model_name": "ViT-B-32",
"description": "OpenAI CLIP model - good general purpose vision-language model"
},
"dinov2": {
"model_name": "dinov2_vitb14",
"description": "Meta DINOv2 - self-supervised vision transformer, good for visual features"
},
"dinov2_registers": {
"model_name": "facebook/dinov2-with-registers-base",
"description": "Meta DINOv2 with register tokens - improved feature maps and attention"
},
"dinov3": {
"model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
"description": "Meta DINOv3 - vision foundation model with high-quality dense features"
},
"siglip": {
"model_name": "google/siglip-base-patch16-224",
"description": "Google SigLIP - improved CLIP-like model with better training"
}
}
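# Minimal usage sketch, assuming model weights can be downloaded and that
# "query.jpg" / "candidate.jpg" stand in for real image paths; not part of the
# module's public entry points.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EmbeddingModelFactory.create_model("clip", device)
    query = Image.open("query.jpg")
    candidate = Image.open("candidate.jpg")
    score = model.compute_similarity(model.encode_image(query), model.encode_image(candidate))
    print(f"{model.get_model_name()} similarity: {score:.4f}")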