Onur Çopur committed on
Commit 0647d62 · 1 Parent(s): 8880ccb

add dinov3 and dinov2 with registers

Files changed (4):
  1. .gitignore +7 -1
  2. CLAUDE.md +203 -0
  3. embeddings.py +278 -1
  4. patch_attention.py +238 -2
.gitignore CHANGED
@@ -108,4 +108,10 @@ jspm_packages/
 
 # temporary folders
 tmp/
-temp/
+temp/
+*.png
+*.jpg
+*.jpeg
+*.gif
+*.svg
+*.mp4
CLAUDE.md ADDED
@@ -0,0 +1,203 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Overview

This is an AI-powered tattoo search engine that combines visual similarity search with image captioning. Users upload a tattoo image, and the system finds visually similar tattoos from across the web using multi-model embeddings and multi-platform search.

**Tech Stack**: FastAPI, PyTorch, HuggingFace Transformers, OpenCLIP, DINOv2, SigLIP

**Deployment**: Dockerized application designed for HuggingFace Spaces (GPU recommended)

## Development Commands

### Running the Application

```bash
# Local development
python app.py

# Docker build and run
docker build -t tattoo-search .
docker run -p 7860:7860 --env-file .env tattoo-search
```

### Environment Setup

Required environment variable:
- `HF_TOKEN`: HuggingFace API token (required for GLM-4.5V captioning via Novita provider)

Create `.env` file:
```
HF_TOKEN=your_token_here
```

### Testing Endpoints

```bash
# Health check
curl http://localhost:7860/health

# Get available models
curl http://localhost:7860/models

# Search with image
curl -X POST http://localhost:7860/search \
  -F "embedding_model=clip" \
  -F "include_patch_attention=false"
```

## Architecture

### Core Pipeline Flow

1. **Image Upload** → FastAPI endpoint (`/search` in main.py)
2. **Caption Generation** → GLM-4.5V via HuggingFace InferenceClient (Novita provider)
3. **Multi-Platform Search** → SearchEngineManager coordinates searches across Pinterest, Reddit, Instagram
4. **URL Validation** → URLValidator filters valid/accessible images
5. **Embedding Extraction** → Selected model (CLIP/DINOv2/SigLIP) encodes query + candidates
6. **Similarity Computation** → Cosine similarity ranking in parallel
7. **Optional Patch Analysis** → PatchAttentionAnalyzer for detailed visual correspondence

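As a rough illustration of how these steps might be wired together in main.py (the handler and helper signatures are assumptions; `get_search_engine()` and the `TattooSearchEngine` methods are described below, not defined here):

```python
# Illustrative pipeline wiring only; not the repository's actual main.py.
import io

from fastapi import FastAPI, File, Form, UploadFile
from PIL import Image

app = FastAPI()

@app.post("/search")
async def search(
    file: UploadFile = File(...),
    embedding_model: str = Form("clip"),
    include_patch_attention: bool = Form(False),
):
    engine = get_search_engine(embedding_model)          # lazy singleton (see Model Selection Logic)
    query_image = Image.open(io.BytesIO(await file.read())).convert("RGB")

    caption = engine.generate_caption(query_image)        # step 2: GLM-4.5V caption
    candidates = engine.search_images(caption)             # steps 3-4: search + URL validation
    results = engine.compute_similarity(                   # steps 5-7: embed, rank, optional patches
        query_image, candidates, include_patch_attention=include_patch_attention
    )
    return {"caption": caption, "results": results}
```
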
### Key Components

**main.py - TattooSearchEngine Class**
- Main orchestration class that ties all components together
- `generate_caption()`: Uses HuggingFace InferenceClient with GLM-4.5V model
- `search_images()`: Delegates to SearchEngineManager with caching
- `download_and_process_image()`: Parallel image download and similarity computation
- `compute_similarity()`: ThreadPoolExecutor for concurrent processing with early stopping

**embeddings.py - Model Abstraction**
- `EmbeddingModel`: Abstract base class defining interface
- `CLIPEmbedding`: OpenAI CLIP ViT-B/32 (default)
- `DINOv2Embedding`: Meta's self-supervised vision transformer
- `SigLIPEmbedding`: Google's improved CLIP-like model
- `EmbeddingModelFactory`: Factory pattern for model instantiation with fallback
- All models support both global image embeddings and patch-level features

**search_engines/ - Multi-Platform Search**
- `SearchEngineManager`: Coordinates parallel searches across platforms with fallback strategies
- `BaseSearchEngine`: Abstract interface for platform-specific engines
- Platform implementations: PinterestSearchEngine, RedditSearchEngine, InstagramSearchEngine
- `SearchResult` and `ImageResult`: Data classes for structured results
- Includes intelligent query simplification for fallback searches

**patch_attention.py - Visual Correspondence**
- `PatchAttentionAnalyzer`: Computes patch-level attention matrices between images
- `compute_patch_similarities()`: Extracts patch features and computes attention
- `visualize_attention_heatmap()`: Creates matplotlib visualizations as base64 PNG
- Returns attention matrices showing which image regions correspond best

**utils/ - Supporting Utilities**
- `SearchCache`: In-memory LRU cache with TTL for search results
- `URLValidator`: Concurrent URL validation to filter broken/inaccessible images

### Model Selection Logic

The search engine supports dynamic model switching via `get_search_engine()`:
- Global singleton pattern with lazy initialization
- Models are swapped only when a different embedding model is requested
- Each model implements both global pooling and patch-level encoding

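A minimal sketch of what this lazy, model-swapping singleton might look like; the module-level globals and the `TattooSearchEngine` constructor argument are illustrative assumptions:

```python
# Hypothetical sketch of get_search_engine(); not the repository's exact code.
_search_engine = None
_current_model = None

def get_search_engine(embedding_model: str = "clip"):
    """Return a shared TattooSearchEngine, rebuilding it only when the
    requested embedding model differs from the one currently loaded."""
    global _search_engine, _current_model
    if _search_engine is None or embedding_model != _current_model:
        _search_engine = TattooSearchEngine(embedding_model=embedding_model)
        _current_model = embedding_model
    return _search_engine
```
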
### Search Strategy

SearchEngineManager uses a tiered approach:
1. Primary platforms (Pinterest, Reddit) searched first
2. If results < threshold, try additional platforms (Instagram)
3. If still insufficient, simplify query and retry
4. All platform searches run concurrently via ThreadPoolExecutor

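A rough sketch of that tiered fallback, assuming each engine exposes a `search(query, max_results)` method returning a list, and using a simple word-dropping heuristic as a stand-in for the real query simplification:

```python
from concurrent.futures import ThreadPoolExecutor

# Illustrative only; the real SearchEngineManager API may differ.
def search_with_fallback(engines, query, max_results=30, threshold=10):
    def run(tier):
        # Run one tier of platform searches concurrently.
        with ThreadPoolExecutor(max_workers=len(tier)) as pool:
            futures = [pool.submit(e.search, query, max_results) for e in tier]
            results = []
            for f in futures:
                try:
                    results.extend(f.result())
                except Exception:
                    pass  # one failing platform should not sink the whole search
            return results

    primary = [engines["pinterest"], engines["reddit"]]
    results = run(primary)
    if len(results) < threshold:                  # tier 2: add Instagram
        results += run([engines["instagram"]])
    if len(results) < threshold:                  # tier 3: retry with a simplified query
        query = " ".join(query.split()[:3])       # run() reads the updated query
        results += run(primary)
    return results
```
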
### Caching Strategy

- Search results cached by query + max_results hash
- Default TTL: 1 hour (3600s)
- Max cache size: 1000 entries with LRU eviction
- Significantly reduces redundant searches

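A minimal sketch of a TTL + LRU cache keyed by a hash of query plus max_results; the class name `SearchCache` comes from the utilities listed above, but this body is illustrative rather than the actual implementation:

```python
import hashlib
import time
from collections import OrderedDict

class SearchCache:
    """Illustrative in-memory LRU cache with TTL; not the repo's exact code."""

    def __init__(self, max_size=1000, ttl=3600):
        self.max_size, self.ttl = max_size, ttl
        self._store = OrderedDict()  # key -> (timestamp, value)

    def _key(self, query: str, max_results: int) -> str:
        return hashlib.md5(f"{query}:{max_results}".encode()).hexdigest()

    def get(self, query, max_results):
        key = self._key(query, max_results)
        entry = self._store.get(key)
        if entry is None or time.time() - entry[0] > self.ttl:
            self._store.pop(key, None)   # expired or missing
            return None
        self._store.move_to_end(key)     # mark as recently used
        return entry[1]

    def set(self, query, max_results, value):
        key = self._key(query, max_results)
        self._store[key] = (time.time(), value)
        self._store.move_to_end(key)
        if len(self._store) > self.max_size:
            self._store.popitem(last=False)  # evict least recently used
```
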
## Important Implementation Details

### Caption Generation
- Uses GLM-4.5V via HuggingFace InferenceClient with Novita provider
- Converts PIL image to base64 data URL
- Expects JSON response with "search_query" field
- Fallback to "tattoo artwork" on failure

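The captioning step might look roughly like this; the data-URL conversion and the `"tattoo artwork"` fallback follow the bullets above, while the exact model id, prompt, and message format are assumptions:

```python
import base64, io, json, os
from huggingface_hub import InferenceClient
from PIL import Image

def generate_caption(image: Image.Image) -> str:
    """Illustrative sketch of GLM-4.5V captioning with a fallback query."""
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    try:
        client = InferenceClient(provider="novita", api_key=os.environ["HF_TOKEN"])
        response = client.chat.completions.create(
            model="zai-org/GLM-4.5V",  # assumed model id
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text",
                     "text": 'Describe this tattoo as a short search query. '
                             'Reply as JSON: {"search_query": "..."}'},
                    {"type": "image_url", "image_url": {"url": data_url}},
                ],
            }],
        )
        return json.loads(response.choices[0].message.content)["search_query"]
    except Exception:
        return "tattoo artwork"  # documented fallback
```
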
### Image Download Headers
- Platform-specific headers (Pinterest, Instagram optimizations)
- Random user agent rotation
- Content-type and size validation (10MB limit, min 50x50px)
- Exponential backoff retry mechanism

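A hedged sketch of a download helper along these lines; the 10MB limit and backoff follow the bullets above, while the helper name, header values, and Pinterest check are illustrative assumptions (the 50x50px check would happen after decoding with PIL):

```python
import random
import time
from typing import Optional

import requests

USER_AGENTS = ["Mozilla/5.0 (X11; Linux x86_64) ...", "Mozilla/5.0 (Macintosh) ..."]  # rotated per request

def download_image(url: str, max_retries: int = 3) -> Optional[bytes]:
    """Illustrative downloader with UA rotation, validation, and backoff."""
    headers = {"User-Agent": random.choice(USER_AGENTS)}
    if "pinimg.com" in url:                          # example of a platform-specific tweak
        headers["Referer"] = "https://www.pinterest.com/"

    for attempt in range(max_retries):
        try:
            resp = requests.get(url, headers=headers, timeout=10)
            resp.raise_for_status()
            if not resp.headers.get("Content-Type", "").startswith("image/"):
                return None                          # reject non-image responses
            if len(resp.content) > 10 * 1024 * 1024:
                return None                          # 10MB limit
            return resp.content
        except requests.RequestException:
            time.sleep(2 ** attempt)                 # exponential backoff before retrying
    return None
```
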
### Similarity Computation
- Early stopping optimization: stops at 20 good results (5 if patch attention enabled)
- ThreadPoolExecutor with max 10 workers
- Rate limiting with 0.1s delays between downloads
- Future cancellation after target reached

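The early-stopping pattern could be sketched as follows (target counts and worker limit come from the bullets above; the function name and result format are assumptions):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

# Illustrative early-stopping loop; not the actual compute_similarity() body.
def rank_candidates(process_one, candidate_urls, include_patch_attention=False):
    target = 5 if include_patch_attention else 20
    results = []
    with ThreadPoolExecutor(max_workers=10) as pool:
        futures = {pool.submit(process_one, url): url for url in candidate_urls}
        for future in as_completed(futures):
            result = future.result()
            if result is not None:
                results.append(result)
            if len(results) >= target:
                # Cancel anything that has not started yet and stop early
                for f in futures:
                    f.cancel()
                break
    return sorted(results, key=lambda r: r["score"], reverse=True)
```
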
### Patch Attention
- Only triggered when `include_patch_attention=true`
- Computes NxM attention matrix (query patches × candidate patches)
- Visualizations include: attention heatmap, patch grid overlays, top correspondences
- Returns base64-encoded PNG images

## API Response Structures

**POST /search** returns:
```json
{
  "caption": "string",
  "results": [
    {
      "score": 0.95,
      "url": "https://...",
      "patch_attention": {          // optional
        "overall_similarity": 0.87,
        "query_grid_size": 7,
        "candidate_grid_size": 7,
        "attention_summary": {...}
      }
    }
  ],
  "embedding_model": "CLIP-ViT-B-32",
  "patch_attention_enabled": false
}
```

**POST /analyze-attention** returns detailed patch analysis with visualizations

## Common Development Patterns

### Adding a New Embedding Model

1. Create new class in `embeddings.py` inheriting from `EmbeddingModel`
2. Implement `load_model()`, `encode_image()`, `encode_image_patches()`, `get_model_name()`
3. Add to `EmbeddingModelFactory.AVAILABLE_MODELS`
4. Add config to `get_default_model_configs()`

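A skeleton for steps 1-4 might look like the following (meant to slot into `embeddings.py`); only the method names and registration points come from the steps above, while `MyModelEmbedding` and its internals are placeholders:

```python
# embeddings.py (sketch) -- hypothetical new model following the interface above
class MyModelEmbedding(EmbeddingModel):
    """Placeholder embedding model illustrating the required interface."""

    def __init__(self, device: torch.device, model_name: str = "my-org/my-model"):
        super().__init__(device)
        self.model_name = model_name
        self.load_model()

    def load_model(self) -> None:
        ...  # load weights + preprocessing onto self.device

    def encode_image(self, image: Image.Image) -> torch.Tensor:
        ...  # return an L2-normalized global embedding, shape [1, dim]

    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
        ...  # return normalized patch features, shape [num_patches, dim]

    def get_model_name(self) -> str:
        return f"MyModel-{self.model_name}"

# Step 3: register it (or add it to the dict literal)
EmbeddingModelFactory.AVAILABLE_MODELS["mymodel"] = MyModelEmbedding

# Step 4: add a config entry to get_default_model_configs(), e.g.
# "mymodel": {"model_name": "my-org/my-model", "description": "..."}
```
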
### Adding a New Search Platform

1. Create new engine in `search_engines/` inheriting from `BaseSearchEngine`
2. Add platform to `SearchPlatform` enum in `base.py`
3. Implement `search()` and `is_valid_url()` methods
4. Add to `SearchEngineManager.engines` dict
5. Update platform prioritization in `search_with_fallback()` if needed

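Similarly, a new platform engine can be sketched like this; the `TumblrSearchEngine` example, the method signatures, and the registration comments are illustrative assumptions:

```python
# search_engines/tumblr.py (hypothetical example)
class TumblrSearchEngine(BaseSearchEngine):
    """Placeholder engine showing the pieces a new platform needs."""

    def search(self, query: str, max_results: int = 20):
        # Query the platform and return structured results
        # (assumed to be ImageResult objects per the component list above).
        return []

    def is_valid_url(self, url: str) -> bool:
        return "tumblr.com" in url

# base.py: add a TUMBLR member to the SearchPlatform enum (step 2)
# SearchEngineManager.__init__: register TumblrSearchEngine() in self.engines (step 4)
```
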
## Performance Considerations

- GPU acceleration used if available (CUDA)
- Concurrent image downloads (ThreadPoolExecutor)
- Search result caching to reduce API calls
- Early stopping in similarity computation
- Future cancellation after targets met
- Model instances reused globally to avoid reloading

## Deployment Notes

- Designed for HuggingFace Spaces with Docker SDK
- Port 7860 (HF Spaces default)
- Recommended hardware: T4 Small GPU or higher
- Health check endpoint at `/health` for monitoring
- All models download on first use and cache in `/app/cache`

embeddings.py CHANGED
@@ -216,6 +216,273 @@ class DINOv2Embedding(EmbeddingModel):
         return f"DINOv2-{self.model_name}"
 
 
+class DINOv2WithRegistersEmbedding(EmbeddingModel):
+    """DINOv2 with register tokens - improved feature maps and attention."""
+
+    def __init__(self, device: torch.device, model_name: str = "facebook/dinov2-with-registers-base"):
+        super().__init__(device)
+        self.model_name = model_name
+        self.processor = None
+        self.load_model()
+
+    def load_model(self) -> None:
+        """Load DINOv2 with registers model and preprocessing."""
+        try:
+            from transformers import Dinov2WithRegistersModel, AutoImageProcessor
+
+            logger.info(f"Loading DINOv2 with registers model: {self.model_name}")
+
+            self.model = Dinov2WithRegistersModel.from_pretrained(self.model_name)
+            self.model.to(self.device)
+            self.model.eval()
+
+            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
+
+            logger.info(f"DINOv2 with registers model {self.model_name} loaded successfully")
+        except Exception as e:
+            logger.error(f"Failed to load DINOv2 with registers model: {e}")
+            raise
+
+    def encode_image(self, image: Image.Image) -> torch.Tensor:
+        """Encode image using DINOv2 with registers."""
+        try:
+            inputs = self.processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                # Use pooler_output for global representation, fallback to mean pooling
+                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
+                    features = outputs.pooler_output
+                else:
+                    # Mean pooling over spatial dimensions
+                    features = outputs.last_hidden_state.mean(dim=1)
+
+                features = F.normalize(features, p=2, dim=1)
+
+            return features
+        except Exception as e:
+            logger.error(f"Failed to encode image with DINOv2 with registers: {e}")
+            raise
+
+    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
+        """Encode image patches using DINOv2 with registers."""
+        try:
+            inputs = self.processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                # Token sequence structure: [CLS] + 4 register tokens + 256 patch tokens = 261 total
+                # We want only the spatial patch tokens (positions 5 to 260)
+
+                num_register_tokens = 4
+                expected_patches = (224 // 14) ** 2  # 256 for base model with 224x224 input, patch size 14
+
+                # Skip CLS token (position 0) and register tokens (positions 1-4)
+                start_idx = 1 + num_register_tokens  # Position 5
+                end_idx = start_idx + expected_patches  # Position 261
+
+                patch_features = outputs.last_hidden_state[:, start_idx:end_idx, :]  # [1, 256, feature_dim]
+
+                # Normalize patch features
+                patch_features = F.normalize(patch_features, p=2, dim=-1)
+
+            return patch_features.squeeze(0)  # [num_patches, feature_dim]
+
+        except Exception as e:
+            logger.error(f"Failed to encode image patches with DINOv2 with registers: {e}")
+            raise
+
+    def get_model_name(self) -> str:
+        return f"DINOv2-WithRegisters-{self.model_name.split('/')[-1]}"
+
+    def get_attention_maps(self, image: Image.Image) -> torch.Tensor:
+        """
+        Extract native attention maps from DINOv2 with registers.
+
+        Returns:
+            Attention tensor with shape (num_layers, num_heads, num_tokens, num_tokens)
+            where num_tokens includes [CLS] + patches + registers
+        """
+        try:
+            inputs = self.processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs, output_attentions=True)
+                # outputs.attentions is a tuple of attention tensors, one per layer
+                # Each has shape: (batch_size, num_heads, sequence_length, sequence_length)
+
+                # Stack all layer attentions
+                attention_stack = torch.stack(outputs.attentions)  # (num_layers, batch_size, num_heads, seq_len, seq_len)
+                attention_stack = attention_stack.squeeze(1)  # Remove batch dimension -> (num_layers, num_heads, seq_len, seq_len)
+
+            return attention_stack
+
+        except Exception as e:
+            logger.error(f"Failed to extract attention maps: {e}")
+            raise
+
+    def compute_cross_attention(self, query_image: Image.Image, candidate_image: Image.Image) -> torch.Tensor:
+        """
+        Compute cross-attention between query and candidate images using patch features.
+
+        This uses the extracted patch embeddings to compute attention from query to candidate,
+        similar to the native attention mechanism but across two images.
+
+        Returns:
+            Cross-attention matrix with shape (query_patches, candidate_patches)
+        """
+        try:
+            # Get patch features for both images
+            query_patches = self.encode_image_patches(query_image)  # (num_query_patches, feature_dim)
+            candidate_patches = self.encode_image_patches(candidate_image)  # (num_candidate_patches, feature_dim)
+
+            # Compute attention-style similarity (softmax over candidate dimension)
+            # attention[i,j] = how much query patch i attends to candidate patch j
+            attention_logits = torch.mm(query_patches, candidate_patches.T)  # (query_patches, candidate_patches)
+
+            # Apply softmax to get attention distribution for each query patch
+            cross_attention = F.softmax(attention_logits, dim=1)
+
+            return cross_attention
+
+        except Exception as e:
+            logger.error(f"Failed to compute cross-attention: {e}")
+            raise
+
+    def supports_native_attention(self) -> bool:
+        """Check if this model supports native attention extraction."""
+        return True
+
+
+class DINOv3Embedding(EmbeddingModel):
+    """DINOv3-based embedding model from HuggingFace transformers."""
+
+    def __init__(self, device: torch.device, model_name: str = "facebook/dinov3-vits16-pretrain-lvd1689m"):
+        super().__init__(device)
+        self.model_name = model_name
+        self.processor = None
+        self.load_model()
+
+    def load_model(self) -> None:
+        """Load DINOv3 model and preprocessing."""
+        try:
+            from transformers import AutoModel, AutoImageProcessor
+
+            logger.info(f"Loading DINOv3 model: {self.model_name}")
+
+            self.model = AutoModel.from_pretrained(self.model_name)
+            self.model.to(self.device)
+            self.model.eval()
+
+            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
+
+            logger.info(f"DINOv3 model {self.model_name} loaded successfully")
+        except Exception as e:
+            logger.error(f"Failed to load DINOv3 model: {e}")
+            raise
+
+    def encode_image(self, image: Image.Image) -> torch.Tensor:
+        """Encode image using DINOv3."""
+        try:
+            inputs = self.processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                # Use pooler_output (CLS token) for global representation
+                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
+                    features = outputs.pooler_output
+                else:
+                    # Fallback to mean pooling over patch embeddings
+                    features = outputs.last_hidden_state[:, 1:, :].mean(dim=1)
+
+                features = F.normalize(features, p=2, dim=1)
+
+            return features
+        except Exception as e:
+            logger.error(f"Failed to encode image with DINOv3: {e}")
+            raise
+
+    def encode_image_patches(self, image: Image.Image) -> torch.Tensor:
+        """Encode image patches using DINOv3."""
+        try:
+            inputs = self.processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                # DINOv3 outputs: [CLS] + register tokens + patch tokens
+                # We want only the patch tokens (skip CLS at position 0 and register tokens)
+                # For DINOv3-ViTS16, it has 4 register tokens
+                num_register_tokens = 4
+                patch_features = outputs.last_hidden_state[:, 1 + num_register_tokens:, :]
+
+                # Normalize patch features
+                patch_features = F.normalize(patch_features, p=2, dim=-1)
+
+            return patch_features.squeeze(0)  # [num_patches, feature_dim]
+
+        except Exception as e:
+            logger.error(f"Failed to encode image patches with DINOv3: {e}")
+            raise
+
+    def get_model_name(self) -> str:
+        return f"DINOv3-{self.model_name.split('/')[-1]}"
+
+    def supports_native_attention(self) -> bool:
+        """Check if this model supports native attention extraction."""
+        return True
+
+    def get_attention_maps(self, image: Image.Image) -> torch.Tensor:
+        """
+        Extract native attention maps from DINOv3.
+
+        Returns:
+            Attention tensor with shape (num_layers, num_heads, num_tokens, num_tokens)
+        """
+        try:
+            inputs = self.processor(images=image, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs, output_attentions=True)
+                # Stack all layer attentions
+                attention_stack = torch.stack(outputs.attentions)
+                attention_stack = attention_stack.squeeze(1)  # Remove batch dimension
+
+            return attention_stack
+
+        except Exception as e:
+            logger.error(f"Failed to extract attention maps: {e}")
+            raise
+
+    def compute_cross_attention(self, query_image: Image.Image, candidate_image: Image.Image) -> torch.Tensor:
+        """
+        Compute cross-attention between query and candidate images using patch features.
+
+        Returns:
+            Cross-attention matrix with shape (query_patches, candidate_patches)
+        """
+        try:
+            query_patches = self.encode_image_patches(query_image)
+            candidate_patches = self.encode_image_patches(candidate_image)
+
+            # Compute attention-style similarity
+            attention_logits = torch.mm(query_patches, candidate_patches.T)
+
+            # Apply softmax to get attention distribution
+            cross_attention = F.softmax(attention_logits, dim=1)
+
+            return cross_attention
+
+        except Exception as e:
+            logger.error(f"Failed to compute cross-attention: {e}")
+            raise
+
+
 class SigLIPEmbedding(EmbeddingModel):
     """SigLIP-based embedding model."""
 
@@ -297,6 +564,8 @@ class EmbeddingModelFactory:
     AVAILABLE_MODELS = {
         "clip": CLIPEmbedding,
         "dinov2": DINOv2Embedding,
+        "dinov2_registers": DINOv2WithRegistersEmbedding,
+        "dinov3": DINOv3Embedding,
         "siglip": SigLIPEmbedding,
     }
 
@@ -305,7 +574,7 @@
         """Create an embedding model instance.
 
         Args:
-            model_type: Type of model ('clip', 'dinov2', 'siglip')
+            model_type: Type of model ('clip', 'dinov2', 'dinov2_registers', 'dinov3', 'siglip')
            device: PyTorch device
            **kwargs: Additional arguments for specific models
 
@@ -345,6 +614,14 @@ def get_default_model_configs() -> Dict[str, Dict[str, Any]]:
         "model_name": "dinov2_vitb14",
         "description": "Meta DINOv2 - self-supervised vision transformer, good for visual features"
     },
+    "dinov2_registers": {
+        "model_name": "facebook/dinov2-with-registers-base",
+        "description": "Meta DINOv2 with register tokens - improved feature maps and attention"
+    },
+    "dinov3": {
+        "model_name": "facebook/dinov3-vits16-pretrain-lvd1689m",
+        "description": "Meta DINOv3 - vision foundation model with high-quality dense features"
+    },
     "siglip": {
         "model_name": "google/siglip-base-patch16-224",
         "description": "Google SigLIP - improved CLIP-like model with better training"
patch_attention.py CHANGED
@@ -15,14 +15,21 @@ class PatchAttentionAnalyzer:
 
     def __init__(self, embedding_model):
         self.embedding_model = embedding_model
+        self.supports_native_attention = hasattr(embedding_model, 'supports_native_attention') and embedding_model.supports_native_attention()
 
     def compute_patch_similarities(self, query_image: Image.Image, candidate_image: Image.Image) -> Dict[str, Any]:
         """
         Compute patch-level similarities between query and candidate images.
+        Automatically uses native attention if the model supports it.
 
         Returns:
             Dictionary containing attention matrix, top correspondences, and metadata
         """
+        # Use native attention if available
+        if self.supports_native_attention:
+            return self.compute_native_attention_similarities(query_image, candidate_image)
+
+        # Fallback to cosine similarity approach
         try:
             # Get patch features for both images
             query_patches = self.embedding_model.encode_image_patches(query_image)
@@ -205,11 +212,61 @@ class PatchAttentionAnalyzer:
 
         return image.crop((left, top, right, bottom))
 
+    def compute_native_attention_similarities(self, query_image: Image.Image, candidate_image: Image.Image) -> Dict[str, Any]:
+        """
+        Compute patch-level similarities using the model's native attention mechanism.
+        Only available for models with native attention support (e.g., DINOv2 with registers).
+
+        Returns:
+            Dictionary containing attention matrix, top correspondences, and metadata
+        """
+        try:
+            # Use model's cross-attention computation
+            attention_matrix = self.embedding_model.compute_cross_attention(query_image, candidate_image)
+            attention_matrix_np = attention_matrix.cpu().numpy()
+
+            # Get patch counts (attention_matrix is already query_patches x candidate_patches)
+            num_query_patches = attention_matrix.shape[0]
+            num_candidate_patches = attention_matrix.shape[1]
+
+            # Get grid dimensions (assuming square patches)
+            query_grid_size = int(math.sqrt(num_query_patches))
+            candidate_grid_size = int(math.sqrt(num_candidate_patches))
+
+            # Find top correspondences for each query patch
+            top_correspondences = []
+            for i in range(num_query_patches):
+                patch_similarities = attention_matrix[i]
+                top_indices = torch.topk(patch_similarities, k=min(5, num_candidate_patches))
+
+                top_correspondences.append({
+                    'query_patch_idx': i,
+                    'query_patch_coord': self._patch_idx_to_coord(i, query_grid_size),
+                    'top_candidate_indices': top_indices.indices.tolist(),
+                    'top_candidate_coords': [self._patch_idx_to_coord(idx.item(), candidate_grid_size)
+                                             for idx in top_indices.indices],
+                    'similarity_scores': top_indices.values.tolist()
+                })
+
+            return {
+                'attention_matrix': attention_matrix_np,
+                'query_grid_size': query_grid_size,
+                'candidate_grid_size': candidate_grid_size,
+                'top_correspondences': top_correspondences,
+                'query_patches_shape': (num_query_patches, attention_matrix.shape[-1]),
+                'candidate_patches_shape': (num_candidate_patches, attention_matrix.shape[-1]),
+                'overall_similarity': torch.mean(attention_matrix).item(),
+                'use_native_attention': True
+            }
+
+        except Exception as e:
+            raise RuntimeError(f"Error computing native attention similarities: {e}")
+
     def get_similarity_summary(self, similarity_data: Dict[str, Any]) -> Dict[str, Any]:
         """Get a summary of similarity statistics."""
         attention_matrix = similarity_data['attention_matrix']
 
-        return {
+        summary = {
             'overall_similarity': similarity_data['overall_similarity'],
             'max_similarity': float(np.max(attention_matrix)),
             'min_similarity': float(np.min(attention_matrix)),
@@ -218,4 +275,183 @@
             'candidate_patches_count': similarity_data['candidate_patches_shape'][0],
             'high_attention_patches': int(np.sum(attention_matrix > (np.mean(attention_matrix) + np.std(attention_matrix)))),
             'model_name': self.embedding_model.get_model_name()
-        }
+        }
+
+        # Add native attention flag if present
+        if 'use_native_attention' in similarity_data:
+            summary['use_native_attention'] = similarity_data['use_native_attention']
+
+        return summary
+
+    def visualize_multihead_attention(self, image: Image.Image, layer_idx: int = -1, figsize: Tuple[int, int] = (20, 12)) -> str:
+        """
+        Visualize attention from multiple heads for a single image.
+        Only available for models with native attention support.
+
+        Args:
+            image: Input image to visualize attention for
+            layer_idx: Which transformer layer to visualize (-1 for last layer)
+            figsize: Figure size for the plot
+
+        Returns:
+            Base64 encoded PNG image showing multi-head attention patterns
+        """
+        if not self.supports_native_attention:
+            raise ValueError("Multi-head attention visualization requires native attention support")
+
+        try:
+            # Get attention maps from the model
+            attention_maps = self.embedding_model.get_attention_maps(image)
+            # Shape: (num_layers, num_heads, num_tokens, num_tokens)
+
+            # Select the specified layer
+            layer_attention = attention_maps[layer_idx]  # (num_heads, num_tokens, num_tokens)
+            num_heads = layer_attention.shape[0]
+
+            # Extract patch-to-patch attention (exclude CLS token and register tokens)
+            # Token sequence structure varies by model:
+            # DINOv2 with registers: [CLS] + 4 register tokens + 256 spatial patches = 261 total
+            # DINOv3: [CLS] + 4 register tokens + 196 spatial patches (patch size 16) = 201 total
+            model_name = self.embedding_model.get_model_name().lower()
+
+            if 'dinov3' in model_name:
+                num_register_tokens = 4
+                expected_patches = 196  # For 224x224 image with patch size 16 (14*14=196)
+            else:
+                num_register_tokens = 4
+                expected_patches = 256  # For 224x224 image with patch size 14
+
+            # Skip CLS token (position 0) and register tokens (positions 1-4)
+            start_idx = 1 + num_register_tokens  # Position 5
+            end_idx = start_idx + expected_patches  # 261 for DINOv2 with registers, 201 for DINOv3
+            patch_attention = layer_attention[:, start_idx:end_idx, start_idx:end_idx]
+
+            # Convert to numpy
+            patch_attention_np = patch_attention.cpu().numpy()
+
+            # Get grid size
+            num_patches = patch_attention.shape[1]
+            grid_size = int(math.sqrt(num_patches))
+
+            # Create subplot grid
+            num_cols = 4
+            num_rows = (num_heads + num_cols - 1) // num_cols  # Ceiling division
+            fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
+            axes = axes.flatten() if num_heads > 1 else [axes]
+
+            layer_name = f"Layer {layer_idx}" if layer_idx >= 0 else f"Last Layer ({len(attention_maps)})"
+            fig.suptitle(f'Multi-Head Attention Patterns - {layer_name}', fontsize=16, fontweight='bold')
+
+            # Plot each head's average attention
+            for head_idx in range(num_heads):
+                # Average attention from all query patches to all key patches
+                head_attn = patch_attention_np[head_idx]
+                avg_attention = np.mean(head_attn, axis=0).reshape(grid_size, grid_size)
+
+                im = axes[head_idx].imshow(avg_attention, cmap='viridis', interpolation='nearest')
+                axes[head_idx].set_title(f'Head {head_idx + 1}')
+                axes[head_idx].axis('off')
+                plt.colorbar(im, ax=axes[head_idx], fraction=0.046, pad=0.04)
+
+            # Hide unused subplots
+            for idx in range(num_heads, len(axes)):
+                axes[idx].axis('off')
+
+            plt.tight_layout()
+
+            # Convert to base64
+            buffer = io.BytesIO()
+            plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
+            buffer.seek(0)
+            plot_data = buffer.getvalue()
+            buffer.close()
+            plt.close()
+
+            return base64.b64encode(plot_data).decode()
+
+        except Exception as e:
+            raise RuntimeError(f"Error visualizing multi-head attention: {e}")
+
+    def visualize_attention_comparison(self, query_image: Image.Image, candidate_image: Image.Image,
+                                       figsize: Tuple[int, int] = (20, 10)) -> str:
+        """
+        Compare native attention vs computed cosine similarity side-by-side.
+        Only available for models with native attention support.
+
+        Args:
+            query_image: Query image
+            candidate_image: Candidate image
+            figsize: Figure size for the plot
+
+        Returns:
+            Base64 encoded PNG showing both attention methods
+        """
+        if not self.supports_native_attention:
+            raise ValueError("Attention comparison requires native attention support")
+
+        try:
+            # Compute native attention
+            native_data = self.compute_native_attention_similarities(query_image, candidate_image)
+
+            # Compute cosine similarity for comparison
+            query_patches = self.embedding_model.encode_image_patches(query_image)
+            candidate_patches = self.embedding_model.encode_image_patches(candidate_image)
+            cosine_attention = self.embedding_model.compute_patch_attention(query_patches, candidate_patches)
+            cosine_attention_np = cosine_attention.cpu().numpy()
+
+            # Create comparison visualization
+            fig, axes = plt.subplots(2, 3, figsize=figsize)
+            fig.suptitle('Native Attention vs Cosine Similarity Comparison', fontsize=16, fontweight='bold')
+
+            # Row 1: Native attention
+            axes[0, 0].imshow(query_image)
+            axes[0, 0].set_title('Query Image')
+            axes[0, 0].axis('off')
+
+            im1 = axes[0, 1].imshow(native_data['attention_matrix'], cmap='viridis', aspect='auto')
+            axes[0, 1].set_title(f'Native Attention\n(Avg: {native_data["overall_similarity"]:.3f})')
+            axes[0, 1].set_xlabel('Candidate Patches')
+            axes[0, 1].set_ylabel('Query Patches')
+            plt.colorbar(im1, ax=axes[0, 1], fraction=0.046, pad=0.04)
+
+            # Max attention heatmap for native
+            max_native = np.max(native_data['attention_matrix'], axis=1)
+            native_grid = max_native.reshape(native_data['query_grid_size'], native_data['query_grid_size'])
+            im2 = axes[0, 2].imshow(native_grid, cmap='hot', interpolation='nearest')
+            axes[0, 2].set_title('Max Native Attention per Patch')
+            plt.colorbar(im2, ax=axes[0, 2], fraction=0.046, pad=0.04)
+
+            # Row 2: Cosine similarity
+            axes[1, 0].imshow(candidate_image)
+            axes[1, 0].set_title('Candidate Image')
+            axes[1, 0].axis('off')
+
+            cosine_mean = float(np.mean(cosine_attention_np))
+            im3 = axes[1, 1].imshow(cosine_attention_np, cmap='viridis', aspect='auto')
+            axes[1, 1].set_title(f'Cosine Similarity\n(Avg: {cosine_mean:.3f})')
+            axes[1, 1].set_xlabel('Candidate Patches')
+            axes[1, 1].set_ylabel('Query Patches')
+            plt.colorbar(im3, ax=axes[1, 1], fraction=0.046, pad=0.04)
+
+            # Max attention heatmap for cosine
+            max_cosine = np.max(cosine_attention_np, axis=1)
+            query_grid_size = int(math.sqrt(query_patches.shape[0]))
+            cosine_grid = max_cosine.reshape(query_grid_size, query_grid_size)
+            im4 = axes[1, 2].imshow(cosine_grid, cmap='hot', interpolation='nearest')
+            axes[1, 2].set_title('Max Cosine Similarity per Patch')
+            plt.colorbar(im4, ax=axes[1, 2], fraction=0.046, pad=0.04)
+
+            plt.tight_layout()
+
+            # Convert to base64
+            buffer = io.BytesIO()
+            plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
+            buffer.seek(0)
+            plot_data = buffer.getvalue()
+            buffer.close()
+            plt.close()
+
+            return base64.b64encode(plot_data).decode()
+
+        except Exception as e:
+            raise RuntimeError(f"Error comparing attention methods: {e}")
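
As a usage sketch (not part of the commit), the analyzer now picks the native-attention path automatically when the embedding model reports support for it; the image paths below are illustrative:

```python
import torch
from PIL import Image
from embeddings import DINOv3Embedding
from patch_attention import PatchAttentionAnalyzer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
analyzer = PatchAttentionAnalyzer(DINOv3Embedding(device))

query = Image.open("query_tattoo.jpg").convert("RGB")
candidate = Image.open("candidate_tattoo.jpg").convert("RGB")

# Uses compute_native_attention_similarities() because DINOv3 supports native attention
data = analyzer.compute_patch_similarities(query, candidate)
print(data["use_native_attention"], data["overall_similarity"])
print(analyzer.get_similarity_summary(data))

# Native-attention extras (these raise for models without attention support)
heads_png_b64 = analyzer.visualize_multihead_attention(query, layer_idx=-1)
comparison_png_b64 = analyzer.visualize_attention_comparison(query, candidate)
```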