from abc import ABC, abstractmethod
from typing import Dict, Any, List

import torch
import torch.nn.functional as F
from PIL import Image
import logging

logger = logging.getLogger(__name__)

class EmbeddingModel(ABC):
    """Abstract base class for embedding models."""

    def __init__(self, device: torch.device):
        self.device = device
        self.model = None
        self.preprocess = None

    @abstractmethod
    def load_model(self) -> None:
        """Load the embedding model and preprocessing."""
        pass

    @abstractmethod
    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode an image into a feature vector."""
        pass

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode an image into patch-level features. Override in subclasses that support it."""
        raise NotImplementedError("Patch-level encoding not implemented for this model")

    def compute_patch_attention(self, query_patches: torch.Tensor, candidate_patches: torch.Tensor) -> torch.Tensor:
        """Compute attention weights between query and candidate patches."""
        # query_patches: [num_query_patches, feature_dim]
        # candidate_patches: [num_candidate_patches, feature_dim]
        # Normalize patches so the dot product below is a cosine similarity
        query_patches = F.normalize(query_patches, p=2, dim=1)
        candidate_patches = F.normalize(candidate_patches, p=2, dim=1)
        # Compute attention matrix: [num_query_patches, num_candidate_patches]
        attention_matrix = torch.mm(query_patches, candidate_patches.T)
        return attention_matrix

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the model name."""
        pass

    def compute_similarity(self, query_features: torch.Tensor, candidate_features: torch.Tensor) -> float:
        """Compute similarity between query and candidate features."""
        # Features are expected to be L2-normalized [1, feature_dim] tensors,
        # so this is the cosine similarity between the two global embeddings.
        return torch.mm(query_features, candidate_features.T).item()
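
# Illustrative sketch, not used by the pipeline: compute_patch_attention is simply a
# cosine-similarity matrix between two L2-normalized patch sets. The patch counts and
# feature dimension below are arbitrary placeholders, not tied to any backbone.
def _example_patch_attention(num_query: int = 16, num_candidate: int = 49, dim: int = 512) -> torch.Tensor:
    query_patches = F.normalize(torch.randn(num_query, dim), p=2, dim=1)
    candidate_patches = F.normalize(torch.randn(num_candidate, dim), p=2, dim=1)
    # Same operation as EmbeddingModel.compute_patch_attention:
    # entry [i, j] is the cosine similarity between query patch i and candidate patch j.
    return torch.mm(query_patches, candidate_patches.T)  # shape [num_query, num_candidate]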

class CLIPEmbedding(EmbeddingModel):
    """CLIP-based embedding model."""

    def __init__(self, device: torch.device, model_name: str = "ViT-B-32"):
        super().__init__(device)
        self.model_name = model_name
        self.tokenizer = None
        self.load_model()

    def load_model(self) -> None:
        """Load CLIP model and preprocessing."""
        try:
            import open_clip
            logger.info(f"Loading CLIP model: {self.model_name}")
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                self.model_name, pretrained="openai"
            )
            self.model.to(self.device)
            self.tokenizer = open_clip.get_tokenizer(self.model_name)
            logger.info(f"CLIP model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load CLIP model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using CLIP."""
        try:
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model.encode_image(image_input)
                features = F.normalize(features, p=2, dim=1)
            return features
        except Exception as e:
            logger.error(f"Failed to encode image with CLIP: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using the CLIP vision transformer."""
        try:
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                # Get patch features from the CLIP vision transformer
                vision_model = self.model.visual
                # Pass through patch embedding and positional encoding
                x = vision_model.conv1(image_input)  # shape = [*, width, grid, grid]
                x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
                x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
                # Add class token and positional embeddings
                x = torch.cat(
                    [
                        vision_model.class_embedding.to(x.dtype)
                        + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                        x,
                    ],
                    dim=1,
                )
                x = x + vision_model.positional_embedding.to(x.dtype)
                # Apply pre-transformer layer norm
                x = vision_model.ln_pre(x)
                x = x.permute(1, 0, 2)  # NLD -> LND
                # Pass through transformer blocks
                for block in vision_model.transformer.resblocks:
                    x = block(x)
                x = x.permute(1, 0, 2)  # LND -> NLD
                # Remove class token to keep only the patch features
                patch_features = x[:, 1:, :]  # [1, num_patches, feature_dim]
                patch_features = vision_model.ln_post(patch_features)
                # Apply projection if it exists
                if vision_model.proj is not None:
                    patch_features = patch_features @ vision_model.proj
                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
            return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with CLIP: {e}")
            raise

    def get_model_name(self) -> str:
        return f"CLIP-{self.model_name}"

class DINOv2Embedding(EmbeddingModel):
    """DINOv2-based embedding model."""

    def __init__(self, device: torch.device, model_name: str = "dinov2_vitb14"):
        super().__init__(device)
        self.model_name = model_name
        self.load_model()

    def load_model(self) -> None:
        """Load DINOv2 model and preprocessing."""
        try:
            import torch.hub
            from torchvision import transforms
            logger.info(f"Loading DINOv2 model: {self.model_name}")
            # Load DINOv2 model from torch hub
            self.model = torch.hub.load('facebookresearch/dinov2', self.model_name)
            self.model.to(self.device)
            self.model.eval()
            # DINOv2 preprocessing
            self.preprocess = transforms.Compose([
                transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            logger.info(f"DINOv2 model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load DINOv2 model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using DINOv2."""
        try:
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model(image_input)
                features = F.normalize(features, p=2, dim=1)
            return features
        except Exception as e:
            logger.error(f"Failed to encode image with DINOv2: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using DINOv2."""
        try:
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                # forward_features returns a dict whose 'x_norm_patchtokens' entry
                # holds the patch-level tokens
                features_dict = self.model.forward_features(image_input)
                patch_features = features_dict['x_norm_patchtokens']  # [1, num_patches, feature_dim]
                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
            return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with DINOv2: {e}")
            raise

    def get_model_name(self) -> str:
        return f"DINOv2-{self.model_name}"

class DINOv2WithRegistersEmbedding(EmbeddingModel):
    """DINOv2 with register tokens - improved feature maps and attention."""

    def __init__(self, device: torch.device, model_name: str = "facebook/dinov2-with-registers-base"):
        super().__init__(device)
        self.model_name = model_name
        self.processor = None
        self.load_model()

    def load_model(self) -> None:
        """Load DINOv2 with registers model and preprocessing."""
        try:
            from transformers import Dinov2WithRegistersModel, AutoImageProcessor
            logger.info(f"Loading DINOv2 with registers model: {self.model_name}")
            self.model = Dinov2WithRegistersModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()
            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            logger.info(f"DINOv2 with registers model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load DINOv2 with registers model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using DINOv2 with registers."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Use pooler_output for the global representation, fall back to mean pooling
                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                    features = outputs.pooler_output
                else:
                    # Mean pooling over the token dimension
                    features = outputs.last_hidden_state.mean(dim=1)
                features = F.normalize(features, p=2, dim=1)
            return features
        except Exception as e:
            logger.error(f"Failed to encode image with DINOv2 with registers: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using DINOv2 with registers."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Token sequence structure: [CLS] + 4 register tokens + 256 patch tokens = 261 total
                # We want only the spatial patch tokens (positions 5 to 260)
                num_register_tokens = 4
                expected_patches = (224 // 14) ** 2  # 256 for the base model with 224x224 input, patch size 14
                # Skip the CLS token (position 0) and the register tokens (positions 1-4)
                start_idx = 1 + num_register_tokens  # Position 5
                end_idx = start_idx + expected_patches  # Position 261
                patch_features = outputs.last_hidden_state[:, start_idx:end_idx, :]  # [1, 256, feature_dim]
                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
            return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with DINOv2 with registers: {e}")
            raise

    def get_model_name(self) -> str:
        return f"DINOv2-WithRegisters-{self.model_name.split('/')[-1]}"

    def get_attention_maps(self, image: Image.Image) -> torch.Tensor:
        """
        Extract native attention maps from DINOv2 with registers.

        Returns:
            Attention tensor with shape (num_layers, num_heads, num_tokens, num_tokens),
            where num_tokens includes [CLS] + register tokens + patch tokens.
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs, output_attentions=True)
                # outputs.attentions is a tuple of per-layer attention tensors, each with
                # shape (batch_size, num_heads, sequence_length, sequence_length)
                attention_stack = torch.stack(outputs.attentions)  # (num_layers, batch_size, num_heads, seq_len, seq_len)
                attention_stack = attention_stack.squeeze(1)  # Drop batch dim -> (num_layers, num_heads, seq_len, seq_len)
            return attention_stack
        except Exception as e:
            logger.error(f"Failed to extract attention maps: {e}")
            raise

    def compute_cross_attention(self, query_image: Image.Image, candidate_image: Image.Image) -> torch.Tensor:
        """
        Compute cross-attention between query and candidate images using patch features.

        This uses the extracted patch embeddings to compute attention from query to candidate,
        similar to the native attention mechanism but across two images.

        Returns:
            Cross-attention matrix with shape (query_patches, candidate_patches)
        """
        try:
            # Get patch features for both images
            query_patches = self.encode_image_patches(query_image)  # (num_query_patches, feature_dim)
            candidate_patches = self.encode_image_patches(candidate_image)  # (num_candidate_patches, feature_dim)
            # Compute attention-style similarity (softmax over the candidate dimension):
            # attention[i, j] = how much query patch i attends to candidate patch j
            attention_logits = torch.mm(query_patches, candidate_patches.T)  # (query_patches, candidate_patches)
            # Apply softmax to get an attention distribution for each query patch
            cross_attention = F.softmax(attention_logits, dim=1)
            return cross_attention
        except Exception as e:
            logger.error(f"Failed to compute cross-attention: {e}")
            raise

    def supports_native_attention(self) -> bool:
        """Check if this model supports native attention extraction."""
        return True
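
# Sketch of how the native attention maps above can be turned into a CLS-to-patch heatmap.
# The token layout ([CLS] + 4 registers + 16x16 patches) mirrors the assumptions already made
# in encode_image_patches; adjust the offsets if a different checkpoint or input size is used.
def _example_cls_attention_heatmap(model: DINOv2WithRegistersEmbedding, image: Image.Image) -> torch.Tensor:
    attentions = model.get_attention_maps(image)          # (num_layers, num_heads, seq_len, seq_len)
    last_layer = attentions[-1].mean(dim=0)               # average heads -> (seq_len, seq_len)
    cls_to_patches = last_layer[0, 1 + 4:]                # CLS row, skipping CLS + register columns
    grid_size = int(cls_to_patches.shape[0] ** 0.5)       # 16 for 224x224 input, patch size 14
    return cls_to_patches.reshape(grid_size, grid_size)   # 16x16 attention heatmap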

class DINOv3Embedding(EmbeddingModel):
    """DINOv3-based embedding model from HuggingFace transformers."""

    def __init__(self, device: torch.device, model_name: str = "facebook/dinov3-vits16-pretrain-lvd1689m"):
        super().__init__(device)
        self.model_name = model_name
        self.processor = None
        self.load_model()

    def load_model(self) -> None:
        """Load DINOv3 model and preprocessing."""
        try:
            from transformers import AutoModel, AutoImageProcessor
            logger.info(f"Loading DINOv3 model: {self.model_name}")
            self.model = AutoModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()
            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            logger.info(f"DINOv3 model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load DINOv3 model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using DINOv3."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                # Use pooler_output (CLS token) for the global representation
                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                    features = outputs.pooler_output
                else:
                    # Fallback: mean pooling over all non-CLS tokens (register + patch tokens)
                    features = outputs.last_hidden_state[:, 1:, :].mean(dim=1)
                features = F.normalize(features, p=2, dim=1)
            return features
        except Exception as e:
            logger.error(f"Failed to encode image with DINOv3: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using DINOv3."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                # DINOv3 token layout: [CLS] + register tokens + patch tokens.
                # Keep only the patch tokens (skip CLS at position 0 and the register tokens);
                # DINOv3-ViTS16 uses 4 register tokens.
                num_register_tokens = 4
                patch_features = outputs.last_hidden_state[:, 1 + num_register_tokens:, :]
                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
            return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with DINOv3: {e}")
            raise

    def get_model_name(self) -> str:
        return f"DINOv3-{self.model_name.split('/')[-1]}"

    def supports_native_attention(self) -> bool:
        """Check if this model supports native attention extraction."""
        return True

    def get_attention_maps(self, image: Image.Image) -> torch.Tensor:
        """
        Extract native attention maps from DINOv3.

        Returns:
            Attention tensor with shape (num_layers, num_heads, num_tokens, num_tokens)
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs, output_attentions=True)
                # Stack all per-layer attentions
                attention_stack = torch.stack(outputs.attentions)
                attention_stack = attention_stack.squeeze(1)  # Remove batch dimension
            return attention_stack
        except Exception as e:
            logger.error(f"Failed to extract attention maps: {e}")
            raise

    def compute_cross_attention(self, query_image: Image.Image, candidate_image: Image.Image) -> torch.Tensor:
        """
        Compute cross-attention between query and candidate images using patch features.

        Returns:
            Cross-attention matrix with shape (query_patches, candidate_patches)
        """
        try:
            query_patches = self.encode_image_patches(query_image)
            candidate_patches = self.encode_image_patches(candidate_image)
            # Compute attention-style similarity
            attention_logits = torch.mm(query_patches, candidate_patches.T)
            # Apply softmax to get an attention distribution for each query patch
            cross_attention = F.softmax(attention_logits, dim=1)
            return cross_attention
        except Exception as e:
            logger.error(f"Failed to compute cross-attention: {e}")
            raise
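
# Usage sketch (image paths are placeholders): each row of the cross-attention matrix is a
# softmax over candidate patches, so it sums to 1 and reads as "where query patch i looks in
# the candidate image". The sanity check below just restates that property.
def _example_dinov3_cross_attention(query_path: str, candidate_path: str) -> torch.Tensor:
    model = DINOv3Embedding(torch.device("cpu"))
    query_image = Image.open(query_path).convert("RGB")
    candidate_image = Image.open(candidate_path).convert("RGB")
    cross_attention = model.compute_cross_attention(query_image, candidate_image)
    # Every query patch distributes a total attention weight of 1.0 over the candidate patches.
    assert torch.allclose(cross_attention.sum(dim=1), torch.ones(cross_attention.shape[0]))
    return cross_attention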

class SigLIPEmbedding(EmbeddingModel):
    """SigLIP-based embedding model."""

    def __init__(self, device: torch.device, model_name: str = "google/siglip-base-patch16-224"):
        super().__init__(device)
        self.model_name = model_name
        self.processor = None
        self.load_model()

    def load_model(self) -> None:
        """Load SigLIP model and preprocessing."""
        try:
            # Check for required dependencies
            try:
                import sentencepiece
            except ImportError:
                raise ImportError(
                    "SentencePiece is required for SigLIP. Install with: pip install sentencepiece"
                )
            from transformers import SiglipVisionModel, SiglipProcessor
            logger.info(f"Loading SigLIP model: {self.model_name}")
            self.model = SiglipVisionModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()
            self.processor = SiglipProcessor.from_pretrained(self.model_name)
            logger.info(f"SigLIP model {self.model_name} loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load SigLIP model: {e}")
            raise

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        """Encode image using SigLIP."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                features = outputs.last_hidden_state.mean(dim=1)  # Global average pooling over patch tokens
                features = F.normalize(features, p=2, dim=1)
            return features
        except Exception as e:
            logger.error(f"Failed to encode image with SigLIP: {e}")
            raise

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        """Encode image patches using SigLIP."""
        try:
            inputs = self.processor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
                # last_hidden_state contains the patch features: [1, num_patches, feature_dim]
                patch_features = outputs.last_hidden_state
                # Normalize patch features
                patch_features = F.normalize(patch_features, p=2, dim=-1)
            return patch_features.squeeze(0)  # [num_patches, feature_dim]
        except Exception as e:
            logger.error(f"Failed to encode image patches with SigLIP: {e}")
            raise

    def get_model_name(self) -> str:
        return f"SigLIP-{self.model_name.split('/')[-1]}"

class EmbeddingModelFactory:
    """Factory class for creating embedding models."""

    AVAILABLE_MODELS = {
        "clip": CLIPEmbedding,
        "dinov2": DINOv2Embedding,
        "dinov2_registers": DINOv2WithRegistersEmbedding,
        "dinov3": DINOv3Embedding,
        "siglip": SigLIPEmbedding,
    }

    @classmethod
    def create_model(cls, model_type: str, device: torch.device, **kwargs) -> EmbeddingModel:
        """Create an embedding model instance.

        Args:
            model_type: Type of model ('clip', 'dinov2', 'dinov2_registers', 'dinov3', 'siglip')
            device: PyTorch device
            **kwargs: Additional arguments for specific models

        Returns:
            EmbeddingModel instance
        """
        if model_type.lower() not in cls.AVAILABLE_MODELS:
            raise ValueError(f"Unknown model type: {model_type}. Available: {list(cls.AVAILABLE_MODELS.keys())}")
        model_class = cls.AVAILABLE_MODELS[model_type.lower()]
        try:
            return model_class(device, **kwargs)
        except Exception as e:
            logger.error(f"Failed to create {model_type} model: {e}")
            # Fall back to CLIP if the requested model fails
            if model_type.lower() != 'clip':
                logger.info("Falling back to CLIP model")
                return cls.AVAILABLE_MODELS['clip'](device, **kwargs)
            else:
                raise

    @classmethod
    def get_available_models(cls) -> List[str]:
        """Get list of available model types."""
        return list(cls.AVAILABLE_MODELS.keys())

def get_default_model_configs() -> Dict[str, Dict[str, Any]]:
    """Get default configurations for each model type."""
    return {
        "clip": {
            "model_name": "ViT-B-32",
            "description": "OpenAI CLIP model - good general purpose vision-language model"
        },
        "dinov2": {
            "model_name": "dinov2_vitb14",
            "description": "Meta DINOv2 - self-supervised vision transformer, good for visual features"
        },
        "dinov2_registers": {
            "model_name": "facebook/dinov2-with-registers-base",
            "description": "Meta DINOv2 with register tokens - improved feature maps and attention"
        },
        "dinov3": {
            "model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
            "description": "Meta DINOv3 - vision foundation model with high-quality dense features"
        },
        "siglip": {
            "model_name": "google/siglip-base-patch16-224",
            "description": "Google SigLIP - improved CLIP-like model with better training"
        }
    }
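
# Minimal end-to-end sketch of the intended entry point, using a synthetic image so no file
# path has to be invented. Any model type in EmbeddingModelFactory.AVAILABLE_MODELS works,
# provided its dependencies (open_clip, transformers, sentencepiece, ...) are installed; the
# printed shape is what the default CLIP ViT-B-32 is expected to produce.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embedder = EmbeddingModelFactory.create_model("clip", device)
    example = Image.new("RGB", (224, 224), color="gray")  # stand-in for a real input image
    features = embedder.encode_image(example)
    print(embedder.get_model_name(), tuple(features.shape))  # e.g. CLIP-ViT-B-32 (1, 512)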