"""patch_attention.py

Patch-level attention analysis and visualization utilities for the tattoo
search engine. Supports DINOv3 and DINOv2-with-registers backbones (native
attention) plus a cosine-similarity fallback for other embedding models.
"""
import numpy as np
import torch
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend for server environments
import matplotlib.pyplot as plt
from PIL import Image
from typing import Tuple, Dict, Any
import io
import base64
import math
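

# The analyzer duck-types against whatever embedding model it wraps. The Protocol
# below is a non-authoritative sketch of that interface, inferred from the calls
# made in this module; the concrete model classes in this repo may declare it
# differently. The last three methods are optional and are only used when
# supports_native_attention() exists and returns True.
from typing import Protocol  # imported here so this sketch stays self-contained


class EmbeddingModelLike(Protocol):
    def get_model_name(self) -> str: ...
    def encode_image_patches(self, image: Image.Image) -> torch.Tensor: ...
    def compute_patch_attention(self, query_patches: torch.Tensor,
                                candidate_patches: torch.Tensor) -> torch.Tensor: ...
    # Optional: only on models with native attention support (DINOv2 w/ registers, DINOv3)
    def supports_native_attention(self) -> bool: ...
    def compute_cross_attention(self, query_image: Image.Image,
                                candidate_image: Image.Image) -> torch.Tensor: ...
    def get_attention_maps(self, image: Image.Image) -> torch.Tensor: ...
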
class PatchAttentionAnalyzer:
"""Utility class for computing and visualizing patch-level attention between images."""
def __init__(self, embedding_model):
self.embedding_model = embedding_model
        # Probe for the optional native-attention capability; fall back to cosine similarity otherwise
        self.supports_native_attention = (
            hasattr(embedding_model, 'supports_native_attention')
            and embedding_model.supports_native_attention()
        )
def compute_patch_similarities(self, query_image: Image.Image, candidate_image: Image.Image) -> Dict[str, Any]:
"""
Compute patch-level similarities between query and candidate images.
Automatically uses native attention if model supports it.
Returns:
Dictionary containing attention matrix, top correspondences, and metadata
"""
# Use native attention if available
if self.supports_native_attention:
return self.compute_native_attention_similarities(query_image, candidate_image)
# Fallback to cosine similarity approach
try:
# Get patch features for both images
query_patches = self.embedding_model.encode_image_patches(query_image)
candidate_patches = self.embedding_model.encode_image_patches(candidate_image)
# Compute attention matrix
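            # The result is expected to be (num_query_patches, num_candidate_patches),
            # each entry a cosine similarity between one query patch and one candidate patch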
attention_matrix = self.embedding_model.compute_patch_attention(query_patches, candidate_patches)
            # Get grid dimensions (assuming a square patch grid, as in standard ViT models)
query_grid_size = int(math.sqrt(query_patches.shape[0]))
candidate_grid_size = int(math.sqrt(candidate_patches.shape[0]))
# Find top correspondences for each query patch
top_correspondences = []
for i in range(attention_matrix.shape[0]):
patch_similarities = attention_matrix[i]
                # torch.topk returns a (values, indices) named tuple of the k best matches
                top_k = torch.topk(patch_similarities, k=min(5, patch_similarities.shape[0]))
                top_correspondences.append({
                    'query_patch_idx': i,
                    'query_patch_coord': self._patch_idx_to_coord(i, query_grid_size),
                    'top_candidate_indices': top_k.indices.tolist(),
                    'top_candidate_coords': [self._patch_idx_to_coord(idx.item(), candidate_grid_size)
                                             for idx in top_k.indices],
                    'similarity_scores': top_k.values.tolist()
                })
return {
'attention_matrix': attention_matrix.cpu().numpy(),
'query_grid_size': query_grid_size,
'candidate_grid_size': candidate_grid_size,
'top_correspondences': top_correspondences,
'query_patches_shape': query_patches.shape,
'candidate_patches_shape': candidate_patches.shape,
'overall_similarity': torch.mean(attention_matrix).item()
}
except NotImplementedError:
raise ValueError(f"Patch-level encoding not supported for {self.embedding_model.get_model_name()}")
except Exception as e:
raise RuntimeError(f"Error computing patch similarities: {e}")
def _patch_idx_to_coord(self, patch_idx: int, grid_size: int) -> Tuple[int, int]:
"""Convert flat patch index to (row, col) coordinate."""
row = patch_idx // grid_size
col = patch_idx % grid_size
return (row, col)
def visualize_attention_heatmap(self, query_image: Image.Image, candidate_image: Image.Image,
similarity_data: Dict[str, Any], figsize: Tuple[int, int] = (15, 10)) -> str:
"""
Create a visualization showing attention heatmap between patches.
Returns base64 encoded PNG image.
"""
attention_matrix = similarity_data['attention_matrix']
query_grid_size = similarity_data['query_grid_size']
candidate_grid_size = similarity_data['candidate_grid_size']
fig, axes = plt.subplots(2, 2, figsize=figsize)
fig.suptitle(f'Patch Attention Analysis - Overall Similarity: {similarity_data["overall_similarity"]:.3f}',
fontsize=14, fontweight='bold')
# Plot original images
axes[0, 0].imshow(query_image)
axes[0, 0].set_title('Query Image')
axes[0, 0].axis('off')
self._overlay_patch_grid(axes[0, 0], query_image.size, query_grid_size)
axes[0, 1].imshow(candidate_image)
axes[0, 1].set_title('Candidate Image')
axes[0, 1].axis('off')
self._overlay_patch_grid(axes[0, 1], candidate_image.size, candidate_grid_size)
# Plot attention matrix
im = axes[1, 0].imshow(attention_matrix, cmap='viridis', aspect='auto')
axes[1, 0].set_title('Attention Matrix')
axes[1, 0].set_xlabel('Candidate Patches')
axes[1, 0].set_ylabel('Query Patches')
plt.colorbar(im, ax=axes[1, 0], fraction=0.046, pad=0.04)
# Plot attention summary (max attention per query patch)
max_attention_per_query = np.max(attention_matrix, axis=1)
attention_grid = max_attention_per_query.reshape(query_grid_size, query_grid_size)
im2 = axes[1, 1].imshow(attention_grid, cmap='hot', interpolation='nearest')
axes[1, 1].set_title('Max Attention per Query Patch')
axes[1, 1].set_xlabel('Patch Column')
axes[1, 1].set_ylabel('Patch Row')
plt.colorbar(im2, ax=axes[1, 1], fraction=0.046, pad=0.04)
plt.tight_layout()
# Convert to base64
buffer = io.BytesIO()
plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
buffer.seek(0)
plot_data = buffer.getvalue()
buffer.close()
plt.close()
return base64.b64encode(plot_data).decode()
def visualize_top_correspondences(self, query_image: Image.Image, candidate_image: Image.Image,
similarity_data: Dict[str, Any], num_top_patches: int = 6) -> str:
"""
Visualize the top corresponding patches between query and candidate images.
Returns base64 encoded PNG image.
"""
top_correspondences = similarity_data['top_correspondences']
query_grid_size = similarity_data['query_grid_size']
candidate_grid_size = similarity_data['candidate_grid_size']
# Sort by best similarity score
sorted_correspondences = sorted(
top_correspondences,
key=lambda x: max(x['similarity_scores']),
reverse=True
)[:num_top_patches]
        # Clamp to the available correspondences; squeeze=False keeps axes indexable as a 2-D grid
        num_top_patches = min(num_top_patches, len(sorted_correspondences))
        fig, axes = plt.subplots(2, num_top_patches, figsize=(3 * num_top_patches, 6), squeeze=False)
fig.suptitle('Top Patch Correspondences', fontsize=14, fontweight='bold')
for i, correspondence in enumerate(sorted_correspondences):
query_coord = correspondence['query_patch_coord']
best_candidate_coord = correspondence['top_candidate_coords'][0]
best_score = correspondence['similarity_scores'][0]
# Extract and show query patch
query_patch = self._extract_patch_from_image(query_image, query_coord, query_grid_size)
axes[0, i].imshow(query_patch)
axes[0, i].set_title(f'Q-Patch {query_coord}\nScore: {best_score:.3f}')
axes[0, i].axis('off')
# Extract and show best matching candidate patch
candidate_patch = self._extract_patch_from_image(candidate_image, best_candidate_coord, candidate_grid_size)
axes[1, i].imshow(candidate_patch)
axes[1, i].set_title(f'C-Patch {best_candidate_coord}')
axes[1, i].axis('off')
plt.tight_layout()
# Convert to base64
buffer = io.BytesIO()
plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
buffer.seek(0)
plot_data = buffer.getvalue()
buffer.close()
plt.close()
return base64.b64encode(plot_data).decode()
def _overlay_patch_grid(self, ax, image_size: Tuple[int, int], grid_size: int):
"""Overlay patch grid lines on image."""
width, height = image_size
patch_width = width / grid_size
patch_height = height / grid_size
# Draw vertical lines
for i in range(1, grid_size):
x = i * patch_width
ax.axvline(x=x, color='white', alpha=0.5, linewidth=1)
# Draw horizontal lines
for i in range(1, grid_size):
y = i * patch_height
ax.axhline(y=y, color='white', alpha=0.5, linewidth=1)
def _extract_patch_from_image(self, image: Image.Image, patch_coord: Tuple[int, int], grid_size: int) -> Image.Image:
"""Extract a specific patch from an image based on grid coordinates."""
row, col = patch_coord
width, height = image.size
patch_width = width // grid_size
patch_height = height // grid_size
left = col * patch_width
top = row * patch_height
right = min((col + 1) * patch_width, width)
bottom = min((row + 1) * patch_height, height)
return image.crop((left, top, right, bottom))
def compute_native_attention_similarities(self, query_image: Image.Image, candidate_image: Image.Image) -> Dict[str, Any]:
"""
Compute patch-level similarities using native attention mechanism.
        Only available for models with native attention support (e.g., DINOv2 with registers, DINOv3).
Returns:
Dictionary containing attention matrix, top correspondences, and metadata
"""
try:
# Use model's cross-attention computation
attention_matrix = self.embedding_model.compute_cross_attention(query_image, candidate_image)
attention_matrix_np = attention_matrix.cpu().numpy()
# Get patch counts (attention_matrix is already query_patches x candidate_patches)
num_query_patches = attention_matrix.shape[0]
num_candidate_patches = attention_matrix.shape[1]
            # Get grid dimensions (assuming a square patch grid)
query_grid_size = int(math.sqrt(num_query_patches))
candidate_grid_size = int(math.sqrt(num_candidate_patches))
# Find top correspondences for each query patch
top_correspondences = []
for i in range(num_query_patches):
patch_similarities = attention_matrix[i]
                # torch.topk returns a (values, indices) named tuple of the k best matches
                top_k = torch.topk(patch_similarities, k=min(5, num_candidate_patches))
                top_correspondences.append({
                    'query_patch_idx': i,
                    'query_patch_coord': self._patch_idx_to_coord(i, query_grid_size),
                    'top_candidate_indices': top_k.indices.tolist(),
                    'top_candidate_coords': [self._patch_idx_to_coord(idx.item(), candidate_grid_size)
                                             for idx in top_k.indices],
                    'similarity_scores': top_k.values.tolist()
                })
return {
'attention_matrix': attention_matrix_np,
'query_grid_size': query_grid_size,
'candidate_grid_size': candidate_grid_size,
'top_correspondences': top_correspondences,
                # The embedding dimension is not recoverable from the attention matrix,
                # so only the patch counts (element 0 of these tuples) are meaningful here
                'query_patches_shape': (num_query_patches, attention_matrix.shape[-1]),
                'candidate_patches_shape': (num_candidate_patches, attention_matrix.shape[-1]),
'overall_similarity': torch.mean(attention_matrix).item(),
'use_native_attention': True
}
except Exception as e:
raise RuntimeError(f"Error computing native attention similarities: {e}")
def get_similarity_summary(self, similarity_data: Dict[str, Any]) -> Dict[str, Any]:
"""Get a summary of similarity statistics."""
attention_matrix = similarity_data['attention_matrix']
summary = {
'overall_similarity': similarity_data['overall_similarity'],
'max_similarity': float(np.max(attention_matrix)),
'min_similarity': float(np.min(attention_matrix)),
'std_similarity': float(np.std(attention_matrix)),
'query_patches_count': similarity_data['query_patches_shape'][0],
'candidate_patches_count': similarity_data['candidate_patches_shape'][0],
'high_attention_patches': int(np.sum(attention_matrix > (np.mean(attention_matrix) + np.std(attention_matrix)))),
'model_name': self.embedding_model.get_model_name()
}
# Add native attention flag if present
if 'use_native_attention' in similarity_data:
summary['use_native_attention'] = similarity_data['use_native_attention']
return summary
def visualize_multihead_attention(self, image: Image.Image, layer_idx: int = -1, figsize: Tuple[int, int] = (20, 12)) -> str:
"""
Visualize attention from multiple heads for a single image.
Only available for models with native attention support.
Args:
image: Input image to visualize attention for
layer_idx: Which transformer layer to visualize (-1 for last layer)
figsize: Figure size for the plot
Returns:
Base64 encoded PNG image showing multi-head attention patterns
"""
if not self.supports_native_attention:
raise ValueError("Multi-head attention visualization requires native attention support")
try:
# Get attention maps from the model
attention_maps = self.embedding_model.get_attention_maps(image)
# Shape: (num_layers, num_heads, num_tokens, num_tokens)
# Select the specified layer
layer_attention = attention_maps[layer_idx] # (num_heads, num_tokens, num_tokens)
num_heads = layer_attention.shape[0]
            # Extract patch-to-patch attention (exclude the CLS token and register tokens).
            # Token sequence structure varies by model (for 224x224 inputs):
            #   DINOv2 with registers: [CLS] + 4 registers + 256 patches (16x16 grid, patch size 14) = 261 tokens
            #   DINOv3:                [CLS] + 4 registers + 196 patches (14x14 grid, patch size 16) = 201 tokens
model_name = self.embedding_model.get_model_name().lower()
if 'dinov3' in model_name:
num_register_tokens = 4
expected_patches = 196 # For 224x224 image with patch size 16 (14*14=196)
else:
num_register_tokens = 4
expected_patches = 256 # For 224x224 image with patch size 14
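            # Note: the patch count could also be derived from the token dimension as
            # layer_attention.shape[-1] - 1 - num_register_tokens, which would avoid
            # hard-coding per-model values for inputs that are not 224x224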
            # Skip the CLS token (position 0) and the register tokens (positions 1-4)
            start_idx = 1 + num_register_tokens  # position 5 in both layouts
            end_idx = start_idx + expected_patches  # 261 for DINOv2 with registers, 201 for DINOv3
patch_attention = layer_attention[:, start_idx:end_idx, start_idx:end_idx]
# Convert to numpy
patch_attention_np = patch_attention.cpu().numpy()
# Get grid size
num_patches = patch_attention.shape[1]
grid_size = int(math.sqrt(num_patches))
# Create subplot grid
num_cols = 4
num_rows = (num_heads + num_cols - 1) // num_cols # Ceiling division
fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
            # Normalize to a flat array of Axes regardless of the subplot grid shape
            axes = np.atleast_1d(axes).ravel()
layer_name = f"Layer {layer_idx}" if layer_idx >= 0 else f"Last Layer ({len(attention_maps)})"
fig.suptitle(f'Multi-Head Attention Patterns - {layer_name}', fontsize=16, fontweight='bold')
# Plot each head's average attention
for head_idx in range(num_heads):
# Average attention from all query patches to all key patches
head_attn = patch_attention_np[head_idx]
avg_attention = np.mean(head_attn, axis=0).reshape(grid_size, grid_size)
im = axes[head_idx].imshow(avg_attention, cmap='viridis', interpolation='nearest')
axes[head_idx].set_title(f'Head {head_idx + 1}')
axes[head_idx].axis('off')
plt.colorbar(im, ax=axes[head_idx], fraction=0.046, pad=0.04)
# Hide unused subplots
for idx in range(num_heads, len(axes)):
axes[idx].axis('off')
plt.tight_layout()
# Convert to base64
buffer = io.BytesIO()
plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
buffer.seek(0)
plot_data = buffer.getvalue()
buffer.close()
plt.close()
return base64.b64encode(plot_data).decode()
except Exception as e:
raise RuntimeError(f"Error visualizing multi-head attention: {e}")
def visualize_attention_comparison(self, query_image: Image.Image, candidate_image: Image.Image,
figsize: Tuple[int, int] = (20, 10)) -> str:
"""
Compare native attention vs computed cosine similarity side-by-side.
Only available for models with native attention support.
Args:
query_image: Query image
candidate_image: Candidate image
figsize: Figure size for the plot
Returns:
Base64 encoded PNG showing both attention methods
"""
if not self.supports_native_attention:
raise ValueError("Attention comparison requires native attention support")
try:
# Compute native attention
native_data = self.compute_native_attention_similarities(query_image, candidate_image)
# Compute cosine similarity for comparison
query_patches = self.embedding_model.encode_image_patches(query_image)
candidate_patches = self.embedding_model.encode_image_patches(candidate_image)
cosine_attention = self.embedding_model.compute_patch_attention(query_patches, candidate_patches)
cosine_attention_np = cosine_attention.cpu().numpy()
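            # Both matrices are expected to be (num_query_patches x num_candidate_patches);
            # the native map comes from the transformer's own attention weights, while this
            # one is raw cosine similarity between independently encoded patch embeddings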
# Create comparison visualization
fig, axes = plt.subplots(2, 3, figsize=figsize)
fig.suptitle('Native Attention vs Cosine Similarity Comparison', fontsize=16, fontweight='bold')
# Row 1: Native attention
axes[0, 0].imshow(query_image)
axes[0, 0].set_title('Query Image')
axes[0, 0].axis('off')
im1 = axes[0, 1].imshow(native_data['attention_matrix'], cmap='viridis', aspect='auto')
axes[0, 1].set_title(f'Native Attention\n(Avg: {native_data["overall_similarity"]:.3f})')
axes[0, 1].set_xlabel('Candidate Patches')
axes[0, 1].set_ylabel('Query Patches')
plt.colorbar(im1, ax=axes[0, 1], fraction=0.046, pad=0.04)
# Max attention heatmap for native
max_native = np.max(native_data['attention_matrix'], axis=1)
native_grid = max_native.reshape(native_data['query_grid_size'], native_data['query_grid_size'])
im2 = axes[0, 2].imshow(native_grid, cmap='hot', interpolation='nearest')
axes[0, 2].set_title('Max Native Attention per Patch')
plt.colorbar(im2, ax=axes[0, 2], fraction=0.046, pad=0.04)
# Row 2: Cosine similarity
axes[1, 0].imshow(candidate_image)
axes[1, 0].set_title('Candidate Image')
axes[1, 0].axis('off')
cosine_mean = float(np.mean(cosine_attention_np))
im3 = axes[1, 1].imshow(cosine_attention_np, cmap='viridis', aspect='auto')
axes[1, 1].set_title(f'Cosine Similarity\n(Avg: {cosine_mean:.3f})')
axes[1, 1].set_xlabel('Candidate Patches')
axes[1, 1].set_ylabel('Query Patches')
plt.colorbar(im3, ax=axes[1, 1], fraction=0.046, pad=0.04)
# Max attention heatmap for cosine
max_cosine = np.max(cosine_attention_np, axis=1)
query_grid_size = int(math.sqrt(query_patches.shape[0]))
cosine_grid = max_cosine.reshape(query_grid_size, query_grid_size)
im4 = axes[1, 2].imshow(cosine_grid, cmap='hot', interpolation='nearest')
axes[1, 2].set_title('Max Cosine Similarity per Patch')
plt.colorbar(im4, ax=axes[1, 2], fraction=0.046, pad=0.04)
plt.tight_layout()
# Convert to base64
buffer = io.BytesIO()
plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
buffer.seek(0)
plot_data = buffer.getvalue()
buffer.close()
plt.close()
return base64.b64encode(plot_data).decode()
except Exception as e:
raise RuntimeError(f"Error comparing attention methods: {e}")