import numpy as np
import torch
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend for server environments
import matplotlib.pyplot as plt
from PIL import Image
from typing import Tuple, Dict, Any
import io
import base64
import math

class PatchAttentionAnalyzer:
    """Utility class for computing and visualizing patch-level attention between images."""

    def __init__(self, embedding_model):
        self.embedding_model = embedding_model
        self.supports_native_attention = (
            hasattr(embedding_model, 'supports_native_attention')
            and embedding_model.supports_native_attention()
        )

    def compute_patch_similarities(self, query_image: Image.Image, candidate_image: Image.Image) -> Dict[str, Any]:
        """
        Compute patch-level similarities between query and candidate images.
        Automatically uses native attention if the model supports it.

        Returns:
            Dictionary containing attention matrix, top correspondences, and metadata
        """
        # Use native attention if available
        if self.supports_native_attention:
            return self.compute_native_attention_similarities(query_image, candidate_image)

        # Fall back to the cosine-similarity approach
        try:
            # Get patch features for both images
            query_patches = self.embedding_model.encode_image_patches(query_image)
            candidate_patches = self.embedding_model.encode_image_patches(candidate_image)

            # Compute attention matrix
            attention_matrix = self.embedding_model.compute_patch_attention(query_patches, candidate_patches)

            # Get grid dimensions (assuming square patch grids for ViT models)
            query_grid_size = int(math.sqrt(query_patches.shape[0]))
            candidate_grid_size = int(math.sqrt(candidate_patches.shape[0]))

            # Find top correspondences for each query patch
            top_correspondences = []
            for i in range(attention_matrix.shape[0]):
                patch_similarities = attention_matrix[i]
                top_indices = torch.topk(patch_similarities, k=min(5, patch_similarities.shape[0]))
                top_correspondences.append({
                    'query_patch_idx': i,
                    'query_patch_coord': self._patch_idx_to_coord(i, query_grid_size),
                    'top_candidate_indices': top_indices.indices.tolist(),
                    'top_candidate_coords': [self._patch_idx_to_coord(idx.item(), candidate_grid_size)
                                             for idx in top_indices.indices],
                    'similarity_scores': top_indices.values.tolist()
                })

            return {
                'attention_matrix': attention_matrix.cpu().numpy(),
                'query_grid_size': query_grid_size,
                'candidate_grid_size': candidate_grid_size,
                'top_correspondences': top_correspondences,
                'query_patches_shape': query_patches.shape,
                'candidate_patches_shape': candidate_patches.shape,
                'overall_similarity': torch.mean(attention_matrix).item()
            }
        except NotImplementedError:
            raise ValueError(f"Patch-level encoding not supported for {self.embedding_model.get_model_name()}")
        except Exception as e:
            raise RuntimeError(f"Error computing patch similarities: {e}")

    def _patch_idx_to_coord(self, patch_idx: int, grid_size: int) -> Tuple[int, int]:
        """Convert flat patch index to (row, col) coordinate."""
        row = patch_idx // grid_size
        col = patch_idx % grid_size
        return (row, col)

    def visualize_attention_heatmap(self, query_image: Image.Image, candidate_image: Image.Image,
                                    similarity_data: Dict[str, Any], figsize: Tuple[int, int] = (15, 10)) -> str:
        """
        Create a visualization showing the attention heatmap between patches.
        Returns base64 encoded PNG image.
        """
        attention_matrix = similarity_data['attention_matrix']
        query_grid_size = similarity_data['query_grid_size']
        candidate_grid_size = similarity_data['candidate_grid_size']

        fig, axes = plt.subplots(2, 2, figsize=figsize)
        fig.suptitle(f'Patch Attention Analysis - Overall Similarity: {similarity_data["overall_similarity"]:.3f}',
                     fontsize=14, fontweight='bold')

        # Plot original images with their patch grids overlaid
        axes[0, 0].imshow(query_image)
        axes[0, 0].set_title('Query Image')
        axes[0, 0].axis('off')
        self._overlay_patch_grid(axes[0, 0], query_image.size, query_grid_size)

        axes[0, 1].imshow(candidate_image)
        axes[0, 1].set_title('Candidate Image')
        axes[0, 1].axis('off')
        self._overlay_patch_grid(axes[0, 1], candidate_image.size, candidate_grid_size)

        # Plot attention matrix
        im = axes[1, 0].imshow(attention_matrix, cmap='viridis', aspect='auto')
        axes[1, 0].set_title('Attention Matrix')
        axes[1, 0].set_xlabel('Candidate Patches')
        axes[1, 0].set_ylabel('Query Patches')
        plt.colorbar(im, ax=axes[1, 0], fraction=0.046, pad=0.04)

        # Plot attention summary (max attention per query patch)
        max_attention_per_query = np.max(attention_matrix, axis=1)
        attention_grid = max_attention_per_query.reshape(query_grid_size, query_grid_size)
        im2 = axes[1, 1].imshow(attention_grid, cmap='hot', interpolation='nearest')
        axes[1, 1].set_title('Max Attention per Query Patch')
        axes[1, 1].set_xlabel('Patch Column')
        axes[1, 1].set_ylabel('Patch Row')
        plt.colorbar(im2, ax=axes[1, 1], fraction=0.046, pad=0.04)

        plt.tight_layout()

        # Convert to base64
        buffer = io.BytesIO()
        plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
        buffer.seek(0)
        plot_data = buffer.getvalue()
        buffer.close()
        plt.close()

        return base64.b64encode(plot_data).decode()

    def visualize_top_correspondences(self, query_image: Image.Image, candidate_image: Image.Image,
                                      similarity_data: Dict[str, Any], num_top_patches: int = 6) -> str:
        """
        Visualize the top corresponding patches between query and candidate images.
        Returns base64 encoded PNG image.
        """
        top_correspondences = similarity_data['top_correspondences']
        query_grid_size = similarity_data['query_grid_size']
        candidate_grid_size = similarity_data['candidate_grid_size']

        # Sort by best similarity score and keep the strongest correspondences
        sorted_correspondences = sorted(
            top_correspondences,
            key=lambda x: max(x['similarity_scores']),
            reverse=True
        )[:num_top_patches]

        num_cols = max(len(sorted_correspondences), 1)
        fig, axes = plt.subplots(2, num_cols, figsize=(3 * num_cols, 6))
        axes = np.asarray(axes).reshape(2, num_cols)  # Keep 2-D indexing even for a single column
        fig.suptitle('Top Patch Correspondences', fontsize=14, fontweight='bold')

        for i, correspondence in enumerate(sorted_correspondences):
            query_coord = correspondence['query_patch_coord']
            best_candidate_coord = correspondence['top_candidate_coords'][0]
            best_score = correspondence['similarity_scores'][0]

            # Extract and show the query patch
            query_patch = self._extract_patch_from_image(query_image, query_coord, query_grid_size)
            axes[0, i].imshow(query_patch)
            axes[0, i].set_title(f'Q-Patch {query_coord}\nScore: {best_score:.3f}')
            axes[0, i].axis('off')

            # Extract and show the best matching candidate patch
            candidate_patch = self._extract_patch_from_image(candidate_image, best_candidate_coord, candidate_grid_size)
            axes[1, i].imshow(candidate_patch)
            axes[1, i].set_title(f'C-Patch {best_candidate_coord}')
            axes[1, i].axis('off')

        plt.tight_layout()

        # Convert to base64
        buffer = io.BytesIO()
        plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
        buffer.seek(0)
        plot_data = buffer.getvalue()
        buffer.close()
        plt.close()

        return base64.b64encode(plot_data).decode()

    def _overlay_patch_grid(self, ax, image_size: Tuple[int, int], grid_size: int):
        """Overlay patch grid lines on image."""
        width, height = image_size
        patch_width = width / grid_size
        patch_height = height / grid_size

        # Draw vertical lines
        for i in range(1, grid_size):
            x = i * patch_width
            ax.axvline(x=x, color='white', alpha=0.5, linewidth=1)

        # Draw horizontal lines
        for i in range(1, grid_size):
            y = i * patch_height
            ax.axhline(y=y, color='white', alpha=0.5, linewidth=1)

    def _extract_patch_from_image(self, image: Image.Image, patch_coord: Tuple[int, int], grid_size: int) -> Image.Image:
        """Extract a specific patch from an image based on grid coordinates."""
        row, col = patch_coord
        width, height = image.size
        patch_width = width // grid_size
        patch_height = height // grid_size

        left = col * patch_width
        top = row * patch_height
        right = min((col + 1) * patch_width, width)
        bottom = min((row + 1) * patch_height, height)

        return image.crop((left, top, right, bottom))

    def compute_native_attention_similarities(self, query_image: Image.Image, candidate_image: Image.Image) -> Dict[str, Any]:
        """
        Compute patch-level similarities using the native attention mechanism.
        Only available for models with native attention support (e.g., DINOv2 with registers).

        Returns:
            Dictionary containing attention matrix, top correspondences, and metadata
        """
        try:
            # Use the model's cross-attention computation
            attention_matrix = self.embedding_model.compute_cross_attention(query_image, candidate_image)
            attention_matrix_np = attention_matrix.cpu().numpy()

            # Get patch counts (attention_matrix is already query_patches x candidate_patches)
            num_query_patches = attention_matrix.shape[0]
            num_candidate_patches = attention_matrix.shape[1]

            # Get grid dimensions (assuming square patch grids)
            query_grid_size = int(math.sqrt(num_query_patches))
            candidate_grid_size = int(math.sqrt(num_candidate_patches))

            # Find top correspondences for each query patch
            top_correspondences = []
            for i in range(num_query_patches):
                patch_similarities = attention_matrix[i]
                top_indices = torch.topk(patch_similarities, k=min(5, num_candidate_patches))
                top_correspondences.append({
                    'query_patch_idx': i,
                    'query_patch_coord': self._patch_idx_to_coord(i, query_grid_size),
                    'top_candidate_indices': top_indices.indices.tolist(),
                    'top_candidate_coords': [self._patch_idx_to_coord(idx.item(), candidate_grid_size)
                                             for idx in top_indices.indices],
                    'similarity_scores': top_indices.values.tolist()
                })

            return {
                'attention_matrix': attention_matrix_np,
                'query_grid_size': query_grid_size,
                'candidate_grid_size': candidate_grid_size,
                'top_correspondences': top_correspondences,
                'query_patches_shape': (num_query_patches, attention_matrix.shape[-1]),
                'candidate_patches_shape': (num_candidate_patches, attention_matrix.shape[-1]),
                'overall_similarity': torch.mean(attention_matrix).item(),
                'use_native_attention': True
            }
        except Exception as e:
            raise RuntimeError(f"Error computing native attention similarities: {e}")

    def get_similarity_summary(self, similarity_data: Dict[str, Any]) -> Dict[str, Any]:
        """Get a summary of similarity statistics."""
        attention_matrix = similarity_data['attention_matrix']

        summary = {
            'overall_similarity': similarity_data['overall_similarity'],
            'max_similarity': float(np.max(attention_matrix)),
            'min_similarity': float(np.min(attention_matrix)),
            'std_similarity': float(np.std(attention_matrix)),
            'query_patches_count': similarity_data['query_patches_shape'][0],
            'candidate_patches_count': similarity_data['candidate_patches_shape'][0],
            'high_attention_patches': int(np.sum(attention_matrix > (np.mean(attention_matrix) + np.std(attention_matrix)))),
            'model_name': self.embedding_model.get_model_name()
        }

        # Add native attention flag if present
        if 'use_native_attention' in similarity_data:
            summary['use_native_attention'] = similarity_data['use_native_attention']

        return summary

    def visualize_multihead_attention(self, image: Image.Image, layer_idx: int = -1, figsize: Tuple[int, int] = (20, 12)) -> str:
        """
        Visualize attention from multiple heads for a single image.
        Only available for models with native attention support.

        Args:
            image: Input image to visualize attention for
            layer_idx: Which transformer layer to visualize (-1 for last layer)
            figsize: Figure size for the plot

        Returns:
            Base64 encoded PNG image showing multi-head attention patterns
        """
        if not self.supports_native_attention:
            raise ValueError("Multi-head attention visualization requires native attention support")

        try:
            # Get attention maps from the model
            attention_maps = self.embedding_model.get_attention_maps(image)
            # Shape: (num_layers, num_heads, num_tokens, num_tokens)

            # Select the specified layer
            layer_attention = attention_maps[layer_idx]  # (num_heads, num_tokens, num_tokens)
            num_heads = layer_attention.shape[0]

            # Extract patch-to-patch attention (exclude CLS token and register tokens).
            # Token sequence structure varies by model:
            #   DINOv2 with registers: [CLS] + 4 register tokens + 256 spatial patches (patch size 14, 16x16 grid) = 261 tokens
            #   DINOv3:                [CLS] + 4 register tokens + 196 spatial patches (patch size 16, 14x14 grid) = 201 tokens
            model_name = self.embedding_model.get_model_name().lower()
            if 'dinov3' in model_name:
                num_register_tokens = 4
                expected_patches = 196  # 224x224 image, patch size 16 -> 14*14 = 196
            else:
                num_register_tokens = 4
                expected_patches = 256  # 224x224 image, patch size 14 -> 16*16 = 256

            # Skip the CLS token (position 0) and register tokens (positions 1-4)
            start_idx = 1 + num_register_tokens
            end_idx = start_idx + expected_patches
            patch_attention = layer_attention[:, start_idx:end_idx, start_idx:end_idx]

            # Convert to numpy
            patch_attention_np = patch_attention.cpu().numpy()

            # Get grid size
            num_patches = patch_attention.shape[1]
            grid_size = int(math.sqrt(num_patches))

            # Create subplot grid
            num_cols = 4
            num_rows = (num_heads + num_cols - 1) // num_cols  # Ceiling division
            fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
            axes = np.asarray(axes).flatten()

            layer_name = f"Layer {layer_idx}" if layer_idx >= 0 else f"Last Layer ({len(attention_maps)})"
            fig.suptitle(f'Multi-Head Attention Patterns - {layer_name}', fontsize=16, fontweight='bold')

            # Plot each head's average attention
            for head_idx in range(num_heads):
                # Average attention from all query patches to all key patches
                head_attn = patch_attention_np[head_idx]
                avg_attention = np.mean(head_attn, axis=0).reshape(grid_size, grid_size)

                im = axes[head_idx].imshow(avg_attention, cmap='viridis', interpolation='nearest')
                axes[head_idx].set_title(f'Head {head_idx + 1}')
                axes[head_idx].axis('off')
                plt.colorbar(im, ax=axes[head_idx], fraction=0.046, pad=0.04)

            # Hide unused subplots
            for idx in range(num_heads, len(axes)):
                axes[idx].axis('off')

            plt.tight_layout()

            # Convert to base64
            buffer = io.BytesIO()
            plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
            buffer.seek(0)
            plot_data = buffer.getvalue()
            buffer.close()
            plt.close()

            return base64.b64encode(plot_data).decode()
        except Exception as e:
            raise RuntimeError(f"Error visualizing multi-head attention: {e}")

    def visualize_attention_comparison(self, query_image: Image.Image, candidate_image: Image.Image,
                                       figsize: Tuple[int, int] = (20, 10)) -> str:
        """
        Compare native attention vs computed cosine similarity side-by-side.
        Only available for models with native attention support.

        Args:
            query_image: Query image
            candidate_image: Candidate image
            figsize: Figure size for the plot

        Returns:
            Base64 encoded PNG showing both attention methods
        """
        if not self.supports_native_attention:
            raise ValueError("Attention comparison requires native attention support")

        try:
            # Compute native attention
            native_data = self.compute_native_attention_similarities(query_image, candidate_image)

            # Compute cosine similarity for comparison
            query_patches = self.embedding_model.encode_image_patches(query_image)
            candidate_patches = self.embedding_model.encode_image_patches(candidate_image)
            cosine_attention = self.embedding_model.compute_patch_attention(query_patches, candidate_patches)
            cosine_attention_np = cosine_attention.cpu().numpy()

            # Create comparison visualization
            fig, axes = plt.subplots(2, 3, figsize=figsize)
            fig.suptitle('Native Attention vs Cosine Similarity Comparison', fontsize=16, fontweight='bold')

            # Row 1: Native attention
            axes[0, 0].imshow(query_image)
            axes[0, 0].set_title('Query Image')
            axes[0, 0].axis('off')

            im1 = axes[0, 1].imshow(native_data['attention_matrix'], cmap='viridis', aspect='auto')
            axes[0, 1].set_title(f'Native Attention\n(Avg: {native_data["overall_similarity"]:.3f})')
            axes[0, 1].set_xlabel('Candidate Patches')
            axes[0, 1].set_ylabel('Query Patches')
            plt.colorbar(im1, ax=axes[0, 1], fraction=0.046, pad=0.04)

            # Max attention heatmap for native attention
            max_native = np.max(native_data['attention_matrix'], axis=1)
            native_grid = max_native.reshape(native_data['query_grid_size'], native_data['query_grid_size'])
            im2 = axes[0, 2].imshow(native_grid, cmap='hot', interpolation='nearest')
            axes[0, 2].set_title('Max Native Attention per Patch')
            plt.colorbar(im2, ax=axes[0, 2], fraction=0.046, pad=0.04)

            # Row 2: Cosine similarity
            axes[1, 0].imshow(candidate_image)
            axes[1, 0].set_title('Candidate Image')
            axes[1, 0].axis('off')

            cosine_mean = float(np.mean(cosine_attention_np))
            im3 = axes[1, 1].imshow(cosine_attention_np, cmap='viridis', aspect='auto')
            axes[1, 1].set_title(f'Cosine Similarity\n(Avg: {cosine_mean:.3f})')
            axes[1, 1].set_xlabel('Candidate Patches')
            axes[1, 1].set_ylabel('Query Patches')
            plt.colorbar(im3, ax=axes[1, 1], fraction=0.046, pad=0.04)

            # Max attention heatmap for cosine similarity
            max_cosine = np.max(cosine_attention_np, axis=1)
            query_grid_size = int(math.sqrt(query_patches.shape[0]))
            cosine_grid = max_cosine.reshape(query_grid_size, query_grid_size)
            im4 = axes[1, 2].imshow(cosine_grid, cmap='hot', interpolation='nearest')
            axes[1, 2].set_title('Max Cosine Similarity per Patch')
            plt.colorbar(im4, ax=axes[1, 2], fraction=0.046, pad=0.04)

            plt.tight_layout()

            # Convert to base64
            buffer = io.BytesIO()
            plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
            buffer.seek(0)
            plot_data = buffer.getvalue()
            buffer.close()
            plt.close()

            return base64.b64encode(plot_data).decode()
        except Exception as e:
            raise RuntimeError(f"Error comparing attention methods: {e}")