from abc import ABC, abstractmethod
from typing import Dict, Any, List

import torch
import torch.nn.functional as F
from PIL import Image
import logging

logger = logging.getLogger(__name__)


class EmbeddingModel(ABC):
    """Abstract base class for embedding models."""

    def __init__(self, device: torch.device):
        self.device = device
        self.model = None
        self.preprocess = None

    @abstractmethod
    def load_model(self) -> None:
        """Load the embedding model and preprocessing."""
        pass

    @abstractmethod
    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode an image into a feature vector."""
        pass

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode an image into patch-level features. Override in subclasses that support it."""
        raise NotImplementedError("Patch-level encoding not implemented for this model")

    def compute_patch_attention(self, query_patches: torch.Tensor,
                                candidate_patches: torch.Tensor) -> torch.Tensor:
        """Compute attention weights between query and candidate patches."""
        # query_patches: [num_query_patches, feature_dim]
        # candidate_patches: [num_candidate_patches, feature_dim]

        # Normalize patches
        query_patches = F.normalize(query_patches, p=2, dim=1)
        candidate_patches = F.normalize(candidate_patches, p=2, dim=1)

        # Compute attention matrix: [num_query_patches, num_candidate_patches]
        attention_matrix = torch.mm(query_patches, candidate_patches.T)
        return attention_matrix

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the model name."""
        pass

    def compute_similarity(self, query_features: torch.Tensor,
                           candidate_features: torch.Tensor) -> float:
        """Compute similarity between query and candidate features."""
        return torch.mm(query_features, candidate_features.T).item()
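
# Typical usage of the base-class helpers with any concrete subclass defined below
# (illustrative sketch; `query_image` and `candidate_image` are assumed PIL images):
#
#     model = CLIPEmbedding(device)                        # or DINOv2Embedding, SigLIPEmbedding, ...
#     q = model.encode_image(query_image)                  # [1, feature_dim], L2-normalized
#     c = model.encode_image(candidate_image)              # [1, feature_dim], L2-normalized
#     score = model.compute_similarity(q, c)               # cosine similarity as a float
#
#     qp = model.encode_image_patches(query_image)         # [num_query_patches, feature_dim]
#     cp = model.encode_image_patches(candidate_image)     # [num_candidate_patches, feature_dim]
#     attn = model.compute_patch_attention(qp, cp)         # [num_query_patches, num_candidate_patches]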


class CLIPEmbedding(EmbeddingModel):
    """CLIP-based embedding model."""

    def __init__(self, device: torch.device, model_name: str = "ViT-B-32"):
        super().__init__(device)
        self.model_name = model_name
        self.tokenizer = None
        self.load_model()

    def load_model(self) -> None:
        """Load CLIP model and preprocessing."""
        try:
            import open_clip

            logger.info(f"Loading CLIP model: {self.model_name}")
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                self.model_name, pretrained="openai"
            )
            self.model.to(self.device)
            self.tokenizer = open_clip.get_tokenizer(self.model_name)
            logger.info(f"CLIP model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using CLIP."""
        try:
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model.encode_image(image_input)
                features = F.normalize(features, p=2, dim=1)
                return features
        except Exception as e:
            logger.error(f"Failed to encode image with CLIP: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using CLIP vision transformer."""
        try:
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                # Get patch features from CLIP vision transformer
                vision_model = self.model.visual

                # Pass through patch embedding and positional encoding
                x = vision_model.conv1(image_input)  # shape = [*, width, grid, grid]
                x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
                x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

                # Add class token and positional embeddings
                x = torch.cat(
                    [
                        vision_model.class_embedding.to(x.dtype)
                        + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                        x,
                    ],
                    dim=1,
                )
                x = x + vision_model.positional_embedding.to(x.dtype)

                # Apply layer norm
                x = vision_model.ln_pre(x)
                x = x.permute(1, 0, 2)  # NLD -> LND

                # Pass through transformer blocks
                for block in vision_model.transformer.resblocks:
                    x = block(x)

                x = x.permute(1, 0, 2)  # LND -> NLD

                # Remove class token to get only patch features
                patch_features = x[:, 1:, :]  # [1, num_patches, feature_dim]
                patch_features = vision_model.ln_post(patch_features)

                # Apply projection if it exists
                if vision_model.proj is not None:
                    patch_features = patch_features @ vision_model.proj

                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
                return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with CLIP: {e}")
            raise

    def get_model_name(self) -> str:
        return f"CLIP-{self.model_name}"


class DINOv2Embedding(EmbeddingModel):
    """DINOv2-based embedding model."""

    def __init__(self, device: torch.device, model_name: str = "dinov2_vitb14"):
        super().__init__(device)
        self.model_name = model_name
        self.load_model()

    def load_model(self) -> None:
        """Load DINOv2 model and preprocessing."""
        try:
            import torch.hub
            from torchvision import transforms

            logger.info(f"Loading DINOv2 model: {self.model_name}")
            # Load DINOv2 model from torch hub
            self.model = torch.hub.load('facebookresearch/dinov2', self.model_name)
            self.model.to(self.device)
            self.model.eval()

            # DINOv2 preprocessing
            self.preprocess = transforms.Compose([
                transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            logger.info(f"DINOv2 model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load DINOv2 model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using DINOv2."""
        try:
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model(image_input)
                features = F.normalize(features, p=2, dim=1)
                return features
        except Exception as e:
            logger.error(f"Failed to encode image with DINOv2: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using DINOv2."""
        try:
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                # Get patch features from DINOv2
                # DINOv2 forward_features returns dict with 'x_norm_patchtokens' containing patch features
                features_dict = self.model.forward_features(image_input)
                patch_features = features_dict['x_norm_patchtokens']  # [1, num_patches, feature_dim]

                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
                return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with DINOv2: {e}")
            raise

    def get_model_name(self) -> str:
        return f"DINOv2-{self.model_name}"


class DINOv2WithRegistersEmbedding(EmbeddingModel):
    """DINOv2 with register tokens - improved feature maps and attention."""

    def __init__(self, device: torch.device, model_name: str = "facebook/dinov2-with-registers-base"):
        super().__init__(device)
        self.model_name = model_name
        self.processor = None
        self.load_model()

    def load_model(self) -> None:
        """Load DINOv2 with registers model and preprocessing."""
        try:
            from transformers import Dinov2WithRegistersModel, AutoImageProcessor

            logger.info(f"Loading DINOv2 with registers model: {self.model_name}")
            self.model = Dinov2WithRegistersModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()

            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            logger.info(f"DINOv2 with registers model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load DINOv2 with registers model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using DINOv2 with registers."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Use pooler_output for global representation, fallback to mean pooling
                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                    features = outputs.pooler_output
                else:
                    # Mean pooling over the token dimension
                    features = outputs.last_hidden_state.mean(dim=1)
                features = F.normalize(features, p=2, dim=1)
                return features
        except Exception as e:
            logger.error(f"Failed to encode image with DINOv2 with registers: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using DINOv2 with registers."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)

                # Token sequence structure: [CLS] + 4 register tokens + 256 patch tokens = 261 total
                # We want only the spatial patch tokens (positions 5 to 260)
                num_register_tokens = 4
                expected_patches = (224 // 14) ** 2  # 256 for base model with 224x224 input, patch size 14

                # Skip CLS token (position 0) and register tokens (positions 1-4)
                start_idx = 1 + num_register_tokens  # Position 5
                end_idx = start_idx + expected_patches  # Position 261
                patch_features = outputs.last_hidden_state[:, start_idx:end_idx, :]  # [1, 256, feature_dim]

                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
                return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with DINOv2 with registers: {e}")
            raise

    def get_model_name(self) -> str:
        return f"DINOv2-WithRegisters-{self.model_name.split('/')[-1]}"

    def get_attention_maps(self, image: Image.Image) -> torch.Tensor:
        """
        Extract native attention maps from DINOv2 with registers.

        Returns:
            Attention tensor with shape (num_layers, num_heads, num_tokens, num_tokens),
            where num_tokens includes the [CLS] token, register tokens, and patch tokens
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs, output_attentions=True)

                # outputs.attentions is a tuple of attention tensors, one per layer
                # Each has shape: (batch_size, num_heads, sequence_length, sequence_length)
                # Stack all layer attentions
                attention_stack = torch.stack(outputs.attentions)  # (num_layers, batch_size, num_heads, seq_len, seq_len)
                attention_stack = attention_stack.squeeze(1)  # Remove batch dimension -> (num_layers, num_heads, seq_len, seq_len)
                return attention_stack
        except Exception as e:
            logger.error(f"Failed to extract attention maps: {e}")
            raise

    def compute_cross_attention(self, query_image: Image.Image, candidate_image: Image.Image) -> torch.Tensor:
        """
        Compute cross-attention between query and candidate images using patch features.

        This uses the extracted patch embeddings to compute attention from query to candidate,
        similar to the native attention mechanism but across two images.

        Returns:
            Cross-attention matrix with shape (query_patches, candidate_patches)
        """
        try:
            # Get patch features for both images
            query_patches = self.encode_image_patches(query_image)  # (num_query_patches, feature_dim)
            candidate_patches = self.encode_image_patches(candidate_image)  # (num_candidate_patches, feature_dim)

            # Compute attention-style similarity (softmax over candidate dimension)
            # attention[i,j] = how much query patch i attends to candidate patch j
            attention_logits = torch.mm(query_patches, candidate_patches.T)  # (query_patches, candidate_patches)

            # Apply softmax to get attention distribution for each query patch
            cross_attention = F.softmax(attention_logits, dim=1)
            return cross_attention
        except Exception as e:
            logger.error(f"Failed to compute cross-attention: {e}")
            raise

    def supports_native_attention(self) -> bool:
        """Check if this model supports native attention extraction."""
        return True
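
# Illustrative usage of the attention utilities above, assuming `model` is a loaded
# DINOv2WithRegistersEmbedding and `query_image` / `candidate_image` are PIL images:
#
#     attn = model.get_attention_maps(query_image)
#     # -> (num_layers, num_heads, num_tokens, num_tokens), tokens = [CLS] + registers + patches
#
#     cross = model.compute_cross_attention(query_image, candidate_image)
#     # -> (num_query_patches, num_candidate_patches); each row is a softmax distribution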


class DINOv3Embedding(EmbeddingModel):
    """DINOv3-based embedding model from HuggingFace transformers."""

    def __init__(self, device: torch.device, model_name: str = "facebook/dinov3-vits16-pretrain-lvd1689m"):
        super().__init__(device)
        self.model_name = model_name
        self.processor = None
        self.load_model()

    def load_model(self) -> None:
        """Load DINOv3 model and preprocessing."""
        try:
            from transformers import AutoModel, AutoImageProcessor

            logger.info(f"Loading DINOv3 model: {self.model_name}")
            self.model = AutoModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()

            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            logger.info(f"DINOv3 model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load DINOv3 model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using DINOv3."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Use pooler_output (CLS token) for global representation
                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                    features = outputs.pooler_output
                else:
                    # Fallback to mean pooling over patch embeddings
                    features = outputs.last_hidden_state[:, 1:, :].mean(dim=1)
                features = F.normalize(features, p=2, dim=1)
                return features
        except Exception as e:
            logger.error(f"Failed to encode image with DINOv3: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using DINOv3."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)

                # DINOv3 outputs: [CLS] + register tokens + patch tokens
                # We want only the patch tokens (skip CLS at position 0 and register tokens)
                # DINOv3-ViTS16 uses 4 register tokens
                num_register_tokens = 4
                patch_features = outputs.last_hidden_state[:, 1 + num_register_tokens:, :]

                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
                return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with DINOv3: {e}")
            raise

    def get_model_name(self) -> str:
        return f"DINOv3-{self.model_name.split('/')[-1]}"

    def supports_native_attention(self) -> bool:
        """Check if this model supports native attention extraction."""
        return True

    def get_attention_maps(self, image: Image.Image) -> torch.Tensor:
        """
        Extract native attention maps from DINOv3.

        Returns:
            Attention tensor with shape (num_layers, num_heads, num_tokens, num_tokens)
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs, output_attentions=True)

                # Stack all layer attentions
                attention_stack = torch.stack(outputs.attentions)
                attention_stack = attention_stack.squeeze(1)  # Remove batch dimension
                return attention_stack
        except Exception as e:
            logger.error(f"Failed to extract attention maps: {e}")
            raise

    def compute_cross_attention(self, query_image: Image.Image, candidate_image: Image.Image) -> torch.Tensor:
        """
        Compute cross-attention between query and candidate images using patch features.

        Returns:
            Cross-attention matrix with shape (query_patches, candidate_patches)
        """
        try:
            query_patches = self.encode_image_patches(query_image)
            candidate_patches = self.encode_image_patches(candidate_image)

            # Compute attention-style similarity
            attention_logits = torch.mm(query_patches, candidate_patches.T)

            # Apply softmax to get attention distribution
            cross_attention = F.softmax(attention_logits, dim=1)
            return cross_attention
        except Exception as e:
            logger.error(f"Failed to compute cross-attention: {e}")
            raise


class SigLIPEmbedding(EmbeddingModel):
    """SigLIP-based embedding model."""

    def __init__(self, device: torch.device, model_name: str = "google/siglip-base-patch16-224"):
        super().__init__(device)
        self.model_name = model_name
        self.processor = None
        self.load_model()

    def load_model(self) -> None:
        """Load SigLIP model and preprocessing."""
        try:
            # Check for required dependencies
            try:
                import sentencepiece
            except ImportError:
                raise ImportError(
                    "SentencePiece is required for SigLIP. Install with: pip install sentencepiece"
                )

            from transformers import SiglipVisionModel, SiglipProcessor

            logger.info(f"Loading SigLIP model: {self.model_name}")
            self.model = SiglipVisionModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()

            self.processor = SiglipProcessor.from_pretrained(self.model_name)
            logger.info(f"SigLIP model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load SigLIP model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using SigLIP."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                features = outputs.last_hidden_state.mean(dim=1)  # Global average pooling
                features = F.normalize(features, p=2, dim=1)
                return features
        except Exception as e:
            logger.error(f"Failed to encode image with SigLIP: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using SigLIP."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                # last_hidden_state contains patch features: [1, num_patches, feature_dim]
                patch_features = outputs.last_hidden_state

                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
                return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with SigLIP: {e}")
            raise

    def get_model_name(self) -> str:
        return f"SigLIP-{self.model_name.split('/')[-1]}"


class EmbeddingModelFactory:
    """Factory class for creating embedding models."""

    AVAILABLE_MODELS = {
        "clip": CLIPEmbedding,
        "dinov2": DINOv2Embedding,
        "dinov2_registers": DINOv2WithRegistersEmbedding,
        "dinov3": DINOv3Embedding,
        "siglip": SigLIPEmbedding,
    }

    @classmethod
    def create_model(cls, model_type: str, device: torch.device, **kwargs) -> EmbeddingModel:
        """Create an embedding model instance.

        Args:
            model_type: Type of model ('clip', 'dinov2', 'dinov2_registers', 'dinov3', 'siglip')
            device: PyTorch device
            **kwargs: Additional arguments for specific models

        Returns:
            EmbeddingModel instance
        """
        if model_type.lower() not in cls.AVAILABLE_MODELS:
            raise ValueError(f"Unknown model type: {model_type}. "
                             f"Available: {list(cls.AVAILABLE_MODELS.keys())}")

        model_class = cls.AVAILABLE_MODELS[model_type.lower()]
        try:
            return model_class(device, **kwargs)
        except Exception as e:
            logger.error(f"Failed to create {model_type} model: {e}")
            # Fallback to CLIP if the requested model fails. Use CLIP's defaults here:
            # kwargs such as a HuggingFace model_name for the failed model would not
            # be valid arguments for the CLIP backend.
            if model_type.lower() != 'clip':
                logger.info("Falling back to CLIP model")
                return cls.AVAILABLE_MODELS['clip'](device)
            else:
                raise

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Get list of available model types."""
        return list(cls.AVAILABLE_MODELS.keys())


def get_default_model_configs() -> Dict[str, Dict[str, Any]]:
    """Get default configurations for each model type."""
    return {
        "clip": {
            "model_name": "ViT-B-32",
            "description": "OpenAI CLIP model - good general purpose vision-language model"
        },
        "dinov2": {
            "model_name": "dinov2_vitb14",
            "description": "Meta DINOv2 - self-supervised vision transformer, good for visual features"
        },
        "dinov2_registers": {
            "model_name": "facebook/dinov2-with-registers-base",
            "description": "Meta DINOv2 with register tokens - improved feature maps and attention"
        },
        "dinov3": {
            "model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
            "description": "Meta DINOv3 - vision foundation model with high-quality dense features"
        },
        "siglip": {
            "model_name": "google/siglip-base-patch16-224",
            "description": "Google SigLIP - improved CLIP-like model with better training"
        }
    }
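

# Example usage (illustrative sketch): create a model through the factory, compute an
# image-level similarity, and build a patch-level attention matrix. The paths
# "query.jpg" and "candidate.jpg" are placeholders and assumed to exist; the CLIP
# backend assumes the open_clip package is installed.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EmbeddingModelFactory.create_model("clip", device)

    query_image = Image.open("query.jpg").convert("RGB")
    candidate_image = Image.open("candidate.jpg").convert("RGB")

    # Global (image-level) similarity
    query_features = model.encode_image(query_image)
    candidate_features = model.encode_image(candidate_image)
    print(f"{model.get_model_name()} similarity:",
          model.compute_similarity(query_features, candidate_features))

    # Patch-level attention matrix: [num_query_patches, num_candidate_patches]
    query_patches = model.encode_image_patches(query_image)
    candidate_patches = model.encode_image_patches(candidate_image)
    attention = model.compute_patch_attention(query_patches, candidate_patches)
    print("Patch attention matrix shape:", tuple(attention.shape))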