AdithyaSK committed
Commit db64b10 · 1 Parent(s): 4392e56

Refactor app.py: update demo description, enhance PDF handling, and improve model loading functions

Files changed (1)
  1. app.py +323 -287
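The main functional change in the diff below is that pages are now embedded in small batches instead of a single `process_images` call over the whole PDF. A minimal sketch of that pattern in plain PyTorch; `model` and `processor` are hypothetical stand-ins for the colpali-engine objects that `app.py` loads, not a documented API:

```python
from typing import List

import torch
from PIL import Image


def embed_pages_in_batches(images: List[Image.Image], model, processor,
                           device: str = "cuda", batch_size: int = 2) -> torch.Tensor:
    """Embed page images a few at a time to keep peak GPU memory low."""
    embeddings_list = []
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size]
        # processor.process_images / model(**inputs) mirror the calls made in app.py;
        # both objects are placeholders here, so adapt to your own encoder pair.
        batch_inputs = processor.process_images(batch).to(device)
        with torch.no_grad():  # inference only, no autograd graph
            embeddings = model(**batch_inputs)
        embeddings_list.append(embeddings.cpu())  # move results off the GPU between batches
    return torch.cat(embeddings_list, dim=0)
```

Keeping each batch's embeddings on CPU and concatenating at the end is what lets the Space cap uploads at ~150 pages without running out of GPU memory, as the new `index_*_images` functions in the diff do.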
app.py CHANGED
@@ -1,28 +1,22 @@
  """
- Gradio Demo for Document Retrieval - Hugging Face Spaces with ZeroGPU
-
- This script creates a Gradio interface for testing both BiGemma3 and ColGemma3 models
- with PDF document upload, automatic conversion to images, and query-based retrieval.
-
- Features:
- - PDF upload with automatic conversion to images
- - Model selection: NetraEmbed (BiGemma3), ColNetraEmbed (ColGemma3), or Both
- - Query input with top-k selection (default: 5)
- - Similarity score display
- - Side-by-side comparison when both models are selected
- - ZeroGPU integration for efficient GPU usage
- """

- import io
- import gc
- import math
- from typing import List, Optional, Tuple

- import gradio as gr
- import torch
  import spaces
  from pdf2image import convert_from_path
  from PIL import Image
  import matplotlib.pyplot as plt
  import numpy as np
  import seaborn as sns
@@ -33,8 +27,6 @@ from colpali_engine.models import BiGemma3, BiGemmaProcessor3, ColGemma3, ColGem
  from colpali_engine.interpretability import get_similarity_maps_from_embeddings
  from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map

- # Configuration
- MAX_BATCH_SIZE = 32 # Maximum pages to process at once
  device = "cuda" if torch.cuda.is_available() else "cpu"

  print(f"Device: {device}")
@@ -54,146 +46,144 @@ class DocumentIndex:

  doc_index = DocumentIndex()

- # Helper functions
- def pdf_to_images(pdf_path: str) -> List[Image.Image]:
- """Convert PDF to list of PIL Images with error handling."""
- try:
- print(f"Converting PDF to images: {pdf_path}")
- images = convert_from_path(pdf_path, dpi=200)
- print(f"Converted {len(images)} pages")
- return images
- except Exception as e:
- print(f"❌ PDF conversion error: {str(e)}")
- raise gr.Error(f"Failed to convert PDF: {str(e)}")

  @spaces.GPU
  def load_bigemma_model():
  """Load BiGemma3 model and processor."""
  if doc_index.bigemma_model is None:
  print("Loading BiGemma3 (NetraEmbed)...")
- try:
- doc_index.bigemma_processor = BiGemmaProcessor3.from_pretrained(
- "Cognitive-Lab/NetraEmbed",
- use_fast=True,
- )
- doc_index.bigemma_model = BiGemma3.from_pretrained(
- "Cognitive-Lab/NetraEmbed",
- torch_dtype=torch.bfloat16,
- device_map=device,
- )
- doc_index.bigemma_model.eval()
- print("✓ BiGemma3 loaded successfully")
- except Exception as e:
- print(f"❌ Failed to load BiGemma3: {str(e)}")
- raise gr.Error(f"Failed to load BiGemma3: {str(e)}")
- return "✅ BiGemma3 loaded"

  @spaces.GPU
  def load_colgemma_model():
  """Load ColGemma3 model and processor."""
  if doc_index.colgemma_model is None:
  print("Loading ColGemma3 (ColNetraEmbed)...")
  try:
- doc_index.colgemma_model = ColGemma3.from_pretrained(
- "Cognitive-Lab/ColNetraEmbed",
- dtype=torch.bfloat16,
- device_map=device,
- )
- doc_index.colgemma_model.eval()
- doc_index.colgemma_processor = ColGemmaProcessor3.from_pretrained(
- "Cognitive-Lab/ColNetraEmbed",
- use_fast=True,
- )
- print("✓ ColGemma3 loaded successfully")
  except Exception as e:
- print(f"❌ Failed to load ColGemma3: {str(e)}")
- raise gr.Error(f"Failed to load ColGemma3: {str(e)}")
- return "✅ ColGemma3 loaded"

- def unload_models():
- """Unload models and free GPU memory."""
- try:
- if doc_index.bigemma_model is not None:
- del doc_index.bigemma_model
- del doc_index.bigemma_processor
- doc_index.bigemma_model = None
- doc_index.bigemma_processor = None
-
- if doc_index.colgemma_model is not None:
- del doc_index.colgemma_model
- del doc_index.colgemma_processor
- doc_index.colgemma_model = None
- doc_index.colgemma_processor = None
-
- # Clear embeddings and images
- doc_index.bigemma_embeddings = None
- doc_index.colgemma_embeddings = None
- doc_index.images = []
-
- # Force garbage collection
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.synchronize()
-
- return "✅ Models unloaded and GPU memory cleared"
- except Exception as e:
- return f"❌ Error unloading models: {str(e)}"

  @spaces.GPU
- def index_bigemma_images(images: List[Image.Image]) -> torch.Tensor:
- """Index images with BiGemma3 model."""
- # Ensure model is loaded
- if doc_index.bigemma_model is None:
- load_bigemma_model()

- model, processor = doc_index.bigemma_model, doc_index.bigemma_processor
- batch_images = processor.process_images(images).to(device)
- embeddings = model(**batch_images, embedding_dim=768)
- return embeddings

  @spaces.GPU
- def index_colgemma_images(images: List[Image.Image]) -> torch.Tensor:
- """Index images with ColGemma3 model."""
- # Ensure model is loaded
- if doc_index.colgemma_model is None:
- load_colgemma_model()

- model, processor = doc_index.colgemma_model, doc_index.colgemma_processor
- batch_images = processor.process_images(images).to(device)
- embeddings = model(**batch_images)
- return embeddings

- def index_document(pdf_file, model_choice: str):
- """Upload and index a PDF document."""
- if pdf_file is None:
- return "⚠️ Please upload a PDF document first."

  try:
- status = []

- # Convert PDF to images
- status.append("⏳ Converting PDF to images...")
- doc_index.images = pdf_to_images(pdf_file.name)
  num_pages = len(doc_index.images)
- status.append(f"✓ Converted PDF to {num_pages} images")
-
- if num_pages > MAX_BATCH_SIZE:
- status.append(f"⚠️ Large PDF ({num_pages} pages). Processing in batches...")

  # Index with BiGemma3
  if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
- status.append("⏳ Loading & encoding with BiGemma3...")
  doc_index.bigemma_embeddings = index_bigemma_images(doc_index.images)
- status.append(f"✓ Indexed with BiGemma3 (shape: {doc_index.bigemma_embeddings.shape})")

  # Index with ColGemma3
  if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
- status.append("⏳ Loading & encoding with ColGemma3...")
  doc_index.colgemma_embeddings = index_colgemma_images(doc_index.images)
- status.append(f"✓ Indexed with ColGemma3 (shape: {doc_index.colgemma_embeddings.shape})")

- return "\n".join(status) + "\n\n✅ Document ready for querying!"

  except Exception as e:
  import traceback
@@ -201,67 +191,59 @@ def index_document(pdf_file, model_choice: str):
  print(f"Indexing error: {error_details}")
  return f"❌ Error indexing document: {str(e)}"

  @spaces.GPU
  def generate_colgemma_heatmap(
  image: Image.Image,
- query: str,
  query_embedding: torch.Tensor,
  image_embedding: torch.Tensor,
- model,
- processor,
  ) -> Image.Image:
  """Generate heatmap overlay for ColGemma3 results."""
  try:
- # Re-process the single image to get the proper batch_images dict for image mask
  batch_images = processor.process_images([image]).to(device)

- # Create image mask manually
  if "input_ids" in batch_images and hasattr(model.config, "image_token_id"):
  image_token_id = model.config.image_token_id
  image_mask = batch_images["input_ids"] == image_token_id
  else:
  image_mask = torch.ones(
- image_embedding.shape[0], image_embedding.shape[1], dtype=torch.bool, device=device
  )

- # Calculate n_patches from actual number of image tokens
  num_image_tokens = image_mask.sum().item()
  n_side = int(math.sqrt(num_image_tokens))
-
- if n_side * n_side == num_image_tokens:
- n_patches = (n_side, n_side)
- else:
- n_patches = (16, 16)

  # Generate similarity maps
  similarity_maps_list = get_similarity_maps_from_embeddings(
- image_embeddings=image_embedding,
- query_embeddings=query_embedding,
  n_patches=n_patches,
  image_mask=image_mask,
  )

  similarity_map = similarity_maps_list[0]
-
- # Aggregate across all query tokens
  if similarity_map.dtype == torch.bfloat16:
  similarity_map = similarity_map.float()
  aggregated_map = torch.mean(similarity_map, dim=0)

- # Convert the image to an array
  img_array = np.array(image.convert("RGBA"))
-
- # Normalize the similarity map
  similarity_map_array = normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()
  similarity_map_array = rearrange(similarity_map_array, "h w -> w h")

- # Create PIL image from similarity map
  similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize(
  image.size, Image.Resampling.BICUBIC
  )

  # Create matplotlib figure
- _, ax = plt.subplots(figsize=(10, 10))
  ax.imshow(img_array)
  ax.imshow(
  similarity_map_image,
@@ -284,210 +266,261 @@ generate_colgemma_heatmap(
  print(f"❌ Heatmap generation error: {str(e)}")
  return image

- @spaces.GPU
- def query_bigemma(query: str, top_k: int) -> Tuple[str, List]:
- """Query indexed documents with BiGemma3."""
- # Ensure model is loaded
- if doc_index.bigemma_model is None:
- load_bigemma_model()
-
- model, processor = doc_index.bigemma_model, doc_index.bigemma_processor
-
- # Encode query
- batch_query = processor.process_texts([query]).to(device)
- query_embedding = model(**batch_query, embedding_dim=768)
-
- # Compute scores
- scores = processor.score(qs=query_embedding, ps=doc_index.bigemma_embeddings)
-
- # Get top-k results
- top_k_actual = min(top_k, len(doc_index.images))
- top_indices = scores[0].argsort(descending=True)[:top_k_actual]
-
- # Format results
- results_text = "### BiGemma3 (NetraEmbed) Results\n\n"
- gallery_images = []
-
- for rank, idx in enumerate(top_indices):
- score = scores[0, idx].item()
- results_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.4f}\n"
- gallery_images.append(
- (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.4f})")
- )
-
- return results_text, gallery_images

  @spaces.GPU
- def query_colgemma(query: str, top_k: int, show_heatmap: bool = False) -> Tuple[str, List]:
- """Query indexed documents with ColGemma3."""
- # Ensure model is loaded
- if doc_index.colgemma_model is None:
- load_colgemma_model()
-
- model, processor = doc_index.colgemma_model, doc_index.colgemma_processor
-
- # Encode query
- batch_query = processor.process_queries([query]).to(device)
- query_embedding = model(**batch_query)
-
- # Compute scores
- scores = processor.score_multi_vector(qs=query_embedding, ps=doc_index.colgemma_embeddings)
-
- # Get top-k results
- top_k_actual = min(top_k, len(doc_index.images))
- top_indices = scores[0].argsort(descending=True)[:top_k_actual]
-
- # Format results
- results_text = "### ColGemma3 (ColNetraEmbed) Results\n\n"
- gallery_images = []
-
- for rank, idx in enumerate(top_indices):
- score = scores[0, idx].item()
- results_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.2f}\n"
-
- # Generate heatmap if requested
- if show_heatmap:
- heatmap_image = generate_colgemma_heatmap(
- image=doc_index.images[idx.item()],
- query=query,
- query_embedding=query_embedding,
- image_embedding=doc_index.colgemma_embeddings[idx.item()].unsqueeze(0),
- model=model,
- processor=processor,
- )
- gallery_images.append(
- (heatmap_image, f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
- )
- else:
- gallery_images.append(
- (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
- )
-
- return results_text, gallery_images
-
  def query_documents(
  query: str, model_choice: str, top_k: int, show_heatmap: bool = False
- ) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[List]]:
  """Query the indexed documents."""
  if not doc_index.images:
- return "⚠️ Please upload and index a document first.", None, None, None

  if not query.strip():
- return "⚠️ Please enter a query.", None, None, None

  try:
- results_bi = None
- results_col = None
- gallery_images_bi = []
- gallery_images_col = []

  # Query with BiGemma3
  if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
  if doc_index.bigemma_embeddings is None:
- return "⚠️ Please index the document with BiGemma3 first.", None, None, None
- results_bi, gallery_images_bi = query_bigemma(query, top_k)

  # Query with ColGemma3
  if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
  if doc_index.colgemma_embeddings is None:
- return "⚠️ Please index the document with ColGemma3 first.", None, None, None
- results_col, gallery_images_col = query_colgemma(query, top_k, show_heatmap)

  # Return results based on model choice
  if model_choice == "NetraEmbed (BiGemma3)":
- return results_bi, None, gallery_images_bi, None
  elif model_choice == "ColNetraEmbed (ColGemma3)":
- return results_col, None, None, gallery_images_col
  else: # Both
- return results_bi, results_col, gallery_images_bi, gallery_images_col

  except Exception as e:
  import traceback
  error_details = traceback.format_exc()
  print(f"Query error: {error_details}")
- return f"❌ Error during query: {str(e)}", None, None, None

  # Create Gradio interface
  with gr.Blocks(title="NetraEmbed Demo") as demo:
  # Header section
- gr.Markdown("# NetraEmbed")
- gr.HTML(
- """
- <div style="display: flex; gap: 8px; flex-wrap: wrap; margin-bottom: 15px;">
- <a href="https://arxiv.org/abs/2512.03514" target="_blank">
- <img src="https://img.shields.io/badge/arXiv-2512.03514-b31b1b.svg" alt="Paper">
- </a>
- <a href="https://github.com/adithya-s-k/colpali" target="_blank">
- <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
- </a>
- <a href="https://huggingface.co/Cognitive-Lab/ColNetraEmbed" target="_blank">
- <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
- </a>
- </div>
- """
- )
- gr.Markdown(
- """
- **🚀 Universal Multilingual Multimodal Document Retrieval**

- Upload a PDF document, select your model(s), and query using semantic search.

- **Available Models:**
- - **NetraEmbed (BiGemma3)**: Single-vector embedding - Fast retrieval with cosine similarity
- - **ColNetraEmbed (ColGemma3)**: Multi-vector embedding - High-quality retrieval with MaxSim scoring and heatmaps
- """
- )

  with gr.Row():
- # Column 1: Model Selection
  with gr.Column(scale=1):
- gr.Markdown("### 🤖 Model Selection")
  model_select = gr.Radio(
  choices=["NetraEmbed (BiGemma3)", "ColNetraEmbed (ColGemma3)", "Both"],
  value="Both",
  label="Select Model(s)",
  )

- # Column 2: Document Upload
- with gr.Column(scale=1):
- gr.Markdown("### 📄 Upload & Index")
- pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
- index_btn = gr.Button("📥 Index Document", variant="primary")
- index_status = gr.Textbox(label="Status", lines=6, interactive=False)

- # Column 3: Query
- with gr.Column(scale=1):
- gr.Markdown("### 🔎 Query Document")
  query_input = gr.Textbox(
  label="Enter Query",
  placeholder="e.g., financial report, organizational structure...",
  lines=2,
  )
  with gr.Row():
- top_k_slider = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Top K", scale=2)
- heatmap_checkbox = gr.Checkbox(label="Heatmaps", value=False, scale=1)
- query_btn = gr.Button("🔍 Search", variant="primary")

  gr.Markdown("---")

  # Results section
- gr.Markdown("### 📊 Results")
- with gr.Row():
  with gr.Column(scale=1):
- bigemma_results = gr.Markdown(value="*BiGemma3 results will appear here...*")
  bigemma_gallery = gr.Gallery(
  label="BiGemma3 - Top Retrieved Pages",
  columns=2,
  height="auto",
  )
  with gr.Column(scale=1):
- colgemma_results = gr.Markdown(value="*ColGemma3 results will appear here...*")
  colgemma_gallery = gr.Gallery(
  label="ColGemma3 - Top Retrieved Pages",
  columns=2,
  height="auto",
  )

  # Event handlers
  index_btn.click(
  fn=index_document,
@@ -498,8 +531,11 @@ with gr.Blocks(title="NetraEmbed Demo") as demo:
  query_btn.click(
  fn=query_documents,
  inputs=[query_input, model_select, top_k_slider, heatmap_checkbox],
- outputs=[bigemma_results, colgemma_results, bigemma_gallery, colgemma_gallery],
  )

- # Launch the app
- demo.launch()

  """
+ NetraEmbed Demo - Document Retrieval with BiGemma3 and ColGemma3
+
+ This demo allows you to:
+ 1. Select a model (NetraEmbed, ColNetraEmbed, or Both)
+ 2. Upload PDF files and index them
+ 3. Search for relevant pages based on your query
+
+ HuggingFace Spaces deployment with ZeroGPU support.
+ """

  import spaces
+ import torch
+ import gradio as gr
  from pdf2image import convert_from_path
  from PIL import Image
+ from typing import List, Tuple, Optional
+ import math
+ import io
  import matplotlib.pyplot as plt
  import numpy as np
  import seaborn as sns
  from colpali_engine.interpretability import get_similarity_maps_from_embeddings
  from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map

  device = "cuda" if torch.cuda.is_available() else "cpu"

  print(f"Device: {device}")

  doc_index = DocumentIndex()

  @spaces.GPU
  def load_bigemma_model():
  """Load BiGemma3 model and processor."""
  if doc_index.bigemma_model is None:
  print("Loading BiGemma3 (NetraEmbed)...")
+ doc_index.bigemma_processor = BiGemmaProcessor3.from_pretrained(
+ "Cognitive-Lab/NetraEmbed",
+ use_fast=True,
+ )
+ doc_index.bigemma_model = BiGemma3.from_pretrained(
+ "Cognitive-Lab/NetraEmbed",
+ torch_dtype=torch.bfloat16,
+ device_map=device,
+ ).eval()
+ print("✓ BiGemma3 loaded successfully")
+ return doc_index.bigemma_model, doc_index.bigemma_processor
+

  @spaces.GPU
  def load_colgemma_model():
  """Load ColGemma3 model and processor."""
  if doc_index.colgemma_model is None:
  print("Loading ColGemma3 (ColNetraEmbed)...")
+ doc_index.colgemma_model = ColGemma3.from_pretrained(
+ "Cognitive-Lab/ColNetraEmbed",
+ dtype=torch.bfloat16,
+ device_map=device,
+ ).eval()
+ doc_index.colgemma_processor = ColGemmaProcessor3.from_pretrained(
+ "Cognitive-Lab/ColNetraEmbed",
+ use_fast=True,
+ )
+ print("✓ ColGemma3 loaded successfully")
+ return doc_index.colgemma_model, doc_index.colgemma_processor
+
+
+ def pdf_to_images(pdf_paths: List[str]) -> List[Image.Image]:
+ """Convert PDF files to list of PIL Images."""
+ images = []
+ for pdf_path in pdf_paths:
  try:
+ print(f"Converting PDF to images: {pdf_path}")
+ page_images = convert_from_path(pdf_path, dpi=200)
+ images.extend(page_images)
+ print(f"Converted {len(page_images)} pages from {pdf_path}")
  except Exception as e:
+ print(f"❌ PDF conversion error for {pdf_path}: {str(e)}")
+ raise gr.Error(f"Failed to convert PDF: {str(e)}")
+
+ if len(images) >= 150:
+ raise gr.Error("The number of images should be less than 150.")
+
+ return images

  @spaces.GPU
+ def index_bigemma_images(images: List[Image.Image]):
+ """Index images with BiGemma3."""
+ model, processor = load_bigemma_model()
+
+ print(f"Indexing {len(images)} images with BiGemma3...")
+ embeddings_list = []
+
+ # Process in smaller batches to avoid memory issues
+ batch_size = 2
+ for i in range(0, len(images), batch_size):
+ batch = images[i:i+batch_size]
+ batch_images = processor.process_images(batch).to(device)
+
+ with torch.no_grad():
+ embeddings = model(**batch_images, embedding_dim=768)
+ embeddings_list.append(embeddings.cpu())
+
+ # Concatenate all embeddings
+ all_embeddings = torch.cat(embeddings_list, dim=0)
+ print(f"✓ Indexed {len(images)} pages with BiGemma3 (shape: {all_embeddings.shape})")
+
+ return all_embeddings

  @spaces.GPU
+ def index_colgemma_images(images: List[Image.Image]):
+ """Index images with ColGemma3."""
+ model, processor = load_colgemma_model()
+
+ print(f"Indexing {len(images)} images with ColGemma3...")
+ embeddings_list = []
+
+ # Process in smaller batches to avoid memory issues
+ batch_size = 2
+ for i in range(0, len(images), batch_size):
+ batch = images[i:i+batch_size]
+ batch_images = processor.process_images(batch).to(device)
+
+ with torch.no_grad():
+ embeddings = model(**batch_images)
+ embeddings_list.append(embeddings.cpu())
+
+ # Concatenate all embeddings
+ all_embeddings = torch.cat(embeddings_list, dim=0)
+ print(f"✓ Indexed {len(images)} pages with ColGemma3 (shape: {all_embeddings.shape})")

+ return all_embeddings

+
+ def index_document(pdf_files, model_choice: str) -> str:
+ """Upload and index PDF documents."""
+ if not pdf_files:
+ return "⚠️ Please upload PDF documents first."
+
+ if not model_choice:
+ return "⚠️ Please select a model first."
  try:
+ status_messages = []
+ # Convert PDFs to images
+ status_messages.append("⏳ Converting PDFs to images...")
+ pdf_paths = [f.name for f in pdf_files]
+ doc_index.images = pdf_to_images(pdf_paths)
  num_pages = len(doc_index.images)
+ status_messages.append(f"✓ Converted to {num_pages} images")

  # Index with BiGemma3
  if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
+ status_messages.append("⏳ Indexing with BiGemma3...")
  doc_index.bigemma_embeddings = index_bigemma_images(doc_index.images)
+ status_messages.append("✓ Indexed with BiGemma3")

  # Index with ColGemma3
  if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
+ status_messages.append("⏳ Indexing with ColGemma3...")
  doc_index.colgemma_embeddings = index_colgemma_images(doc_index.images)
+ status_messages.append("✓ Indexed with ColGemma3")

+ final_status = "\n".join(status_messages) + "\n\n✅ Document ready for querying!"
+ return final_status

  except Exception as e:
  import traceback
  print(f"Indexing error: {error_details}")
  return f"❌ Error indexing document: {str(e)}"

+
  @spaces.GPU
  def generate_colgemma_heatmap(
  image: Image.Image,
  query_embedding: torch.Tensor,
  image_embedding: torch.Tensor,
  ) -> Image.Image:
  """Generate heatmap overlay for ColGemma3 results."""
  try:
+ model, processor = load_colgemma_model()
+
+ # Re-process the single image
  batch_images = processor.process_images([image]).to(device)

+ # Create image mask
  if "input_ids" in batch_images and hasattr(model.config, "image_token_id"):
  image_token_id = model.config.image_token_id
  image_mask = batch_images["input_ids"] == image_token_id
  else:
  image_mask = torch.ones(
+ image_embedding.shape[0], image_embedding.shape[1],
+ dtype=torch.bool, device=device
  )

+ # Calculate n_patches
  num_image_tokens = image_mask.sum().item()
  n_side = int(math.sqrt(num_image_tokens))
+ n_patches = (n_side, n_side) if n_side * n_side == num_image_tokens else (16, 16)

  # Generate similarity maps
  similarity_maps_list = get_similarity_maps_from_embeddings(
+ image_embeddings=image_embedding.unsqueeze(0).to(device),
+ query_embeddings=query_embedding.to(device),
  n_patches=n_patches,
  image_mask=image_mask,
  )

  similarity_map = similarity_maps_list[0]
  if similarity_map.dtype == torch.bfloat16:
  similarity_map = similarity_map.float()
  aggregated_map = torch.mean(similarity_map, dim=0)

+ # Create heatmap overlay
  img_array = np.array(image.convert("RGBA"))
  similarity_map_array = normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()
  similarity_map_array = rearrange(similarity_map_array, "h w -> w h")

  similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize(
  image.size, Image.Resampling.BICUBIC
  )

  # Create matplotlib figure
+ fig, ax = plt.subplots(figsize=(10, 10))
  ax.imshow(img_array)
  ax.imshow(
  similarity_map_image,
  print(f"❌ Heatmap generation error: {str(e)}")
  return image

  @spaces.GPU
  def query_documents(
  query: str, model_choice: str, top_k: int, show_heatmap: bool = False
+ ) -> Tuple[Optional[List], Optional[str], Optional[List], Optional[str]]:
  """Query the indexed documents."""
  if not doc_index.images:
+ return None, "⚠️ Please upload and index a document first.", None, None

  if not query.strip():
+ return None, "⚠️ Please enter a query.", None, None

  try:
+ bigemma_results = []
+ bigemma_text = ""
+ colgemma_results = []
+ colgemma_text = ""

  # Query with BiGemma3
  if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
  if doc_index.bigemma_embeddings is None:
+ return None, "⚠️ Please index the document with BiGemma3 first.", None, None
+
+ model, processor = load_bigemma_model()
+
+ # Encode query
+ batch_query = processor.process_texts([query]).to(device)
+ with torch.no_grad():
+ query_embedding = model(**batch_query, embedding_dim=768)
+
+ # Compute scores
+ scores = processor.score(
+ qs=[query_embedding[0].cpu()],
+ ps=list(torch.unbind(doc_index.bigemma_embeddings)),
+ device=device,
+ )
+
+ # Get top-k results
+ top_k_actual = min(top_k, len(doc_index.images))
+ top_indices = scores[0].argsort(descending=True)[:top_k_actual]
+
+ # Format results
+ bigemma_text = "### BiGemma3 (NetraEmbed) Results\n\n"
+ for rank, idx in enumerate(top_indices):
+ score = scores[0, idx].item()
+ bigemma_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.4f}\n"
+ bigemma_results.append(
+ (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.4f})")
+ )

  # Query with ColGemma3
  if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
  if doc_index.colgemma_embeddings is None:
+ return bigemma_results if bigemma_results else None, bigemma_text if bigemma_text else "⚠️ Please index the document with ColGemma3 first.", None, None
+
+ model, processor = load_colgemma_model()
+
+ # Encode query
+ batch_query = processor.process_queries([query]).to(device)
+ with torch.no_grad():
+ query_embedding = model(**batch_query)
+
+ # Compute scores
+ scores = processor.score_multi_vector(
+ qs=[query_embedding[0].cpu()],
+ ps=list(torch.unbind(doc_index.colgemma_embeddings)),
+ device=device,
+ )
+
+ # Get top-k results
+ top_k_actual = min(top_k, len(doc_index.images))
+ top_indices = scores[0].argsort(descending=True)[:top_k_actual]
+
+ # Format results
+ colgemma_text = "### ColGemma3 (ColNetraEmbed) Results\n\n"
+ for rank, idx in enumerate(top_indices):
+ score = scores[0, idx].item()
+ colgemma_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.2f}\n"
+
+ # Generate heatmap if requested
+ if show_heatmap:
+ heatmap_image = generate_colgemma_heatmap(
+ image=doc_index.images[idx.item()],
+ query_embedding=query_embedding,
+ image_embedding=doc_index.colgemma_embeddings[idx.item()],
+ )
+ colgemma_results.append(
+ (heatmap_image, f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
+ )
+ else:
+ colgemma_results.append(
+ (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
+ )

  # Return results based on model choice
  if model_choice == "NetraEmbed (BiGemma3)":
+ return bigemma_results, bigemma_text, None, None
  elif model_choice == "ColNetraEmbed (ColGemma3)":
+ return None, None, colgemma_results, colgemma_text
  else: # Both
+ return bigemma_results, bigemma_text, colgemma_results, colgemma_text

  except Exception as e:
  import traceback
  error_details = traceback.format_exc()
  print(f"Query error: {error_details}")
+ return None, f"❌ Error during query: {str(e)}", None, None
+

  # Create Gradio interface
  with gr.Blocks(title="NetraEmbed Demo") as demo:
  # Header section
+ with gr.Row():
+ with gr.Column(scale=1):
+ gr.Markdown("# NetraEmbed")
+ gr.HTML(
+ """
+ <div style="display: flex; gap: 8px; flex-wrap: wrap; margin-bottom: 15px;">
+ <a href="https://arxiv.org/abs/2512.03514" target="_blank">
+ <img src="https://img.shields.io/badge/arXiv-2512.03514-b31b1b.svg" alt="Paper">
+ </a>
+ <a href="https://github.com/adithya-s-k/colpali" target="_blank">
+ <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
+ </a>
+ <a href="https://huggingface.co/Cognitive-Lab/ColNetraEmbed" target="_blank">
+ <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
+ </a>
+ <a href="https://www.cognitivelab.in/blog/introducing-netraembed" target="_blank">
+ <img src="https://img.shields.io/badge/Blog-CognitiveLab-blue" alt="Blog">
+ </a>
+ <a href="https://cloud.cognitivelab.in" target="_blank">
+ <img src="https://img.shields.io/badge/Demo-Try%20it%20out-green" alt="Demo">
+ </a>
+ </div>
+ """
+ )
+ gr.Markdown(
+ """
+ **🚀 Universal Multilingual Multimodal Document Retrieval**

+ Upload a PDF document, select your model(s), and query using semantic search.
+
+ **Available Models:**
+ - **NetraEmbed (BiGemma3)**: Single-vector embedding with Matryoshka representation
+ Fast retrieval with cosine similarity
+ - **ColNetraEmbed (ColGemma3)**: Multi-vector embedding with late interaction
+ High-quality retrieval with MaxSim scoring and attention heatmaps
+
+ """
+ )

+ with gr.Column(scale=1):
+ gr.HTML(
+ """
+ <div style="text-align: center;">
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/6442d975ad54813badc1ddf7/-fYMikXhSuqRqm-UIdulK.png"
+ alt="NetraEmbed Banner"
+ style="width: 100%; height: auto; border-radius: 8px;">
+ </div>
+ """
+ )
+
+ gr.Markdown("---")
+
+ # Main interface
  with gr.Row():
+ # Column 1: Model & Upload
  with gr.Column(scale=1):
+ gr.Markdown("### 🤖 Select Model & Upload")
  model_select = gr.Radio(
  choices=["NetraEmbed (BiGemma3)", "ColNetraEmbed (ColGemma3)", "Both"],
  value="Both",
  label="Select Model(s)",
  )

+ pdf_upload = gr.File(
+ label="Upload PDFs",
+ file_types=[".pdf"],
+ file_count="multiple"
+ )
+ index_btn = gr.Button("📥 Index Documents", variant="primary", size="sm")

+ index_status = gr.Textbox(
+ label="Indexing Status",
+ lines=8,
+ interactive=False,
+ value="Select model and upload PDFs to start",
+ )
+
+ # Column 2: Query & Results
+ with gr.Column(scale=2):
+ gr.Markdown("### 🔎 Query Documents")
  query_input = gr.Textbox(
  label="Enter Query",
  placeholder="e.g., financial report, organizational structure...",
  lines=2,
  )
+
  with gr.Row():
+ top_k_slider = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=5,
+ step=1,
+ label="Top K Results",
+ scale=2,
+ )
+ heatmap_checkbox = gr.Checkbox(
+ label="Show Heatmaps (ColGemma3)",
+ value=False,
+ scale=1,
+ )
+
+ query_btn = gr.Button("🔍 Search", variant="primary", size="sm")

  gr.Markdown("---")
+ gr.Markdown("### 📊 Results")

  # Results section
+ with gr.Row(equal_height=True):
  with gr.Column(scale=1):
+ bigemma_results_text = gr.Markdown(
+ value="*BiGemma3 results will appear here...*",
+ )
  bigemma_gallery = gr.Gallery(
  label="BiGemma3 - Top Retrieved Pages",
+ show_label=True,
  columns=2,
  height="auto",
+ object_fit="contain",
  )
  with gr.Column(scale=1):
+ colgemma_results_text = gr.Markdown(
+ value="*ColGemma3 results will appear here...*",
+ )
  colgemma_gallery = gr.Gallery(
  label="ColGemma3 - Top Retrieved Pages",
+ show_label=True,
  columns=2,
  height="auto",
+ object_fit="contain",
  )

+ # Tips
+ with gr.Accordion("💡 Tips", open=False):
+ gr.Markdown(
+ """
+ - **Both models**: Compare results side-by-side
+ - **Scores**: BiGemma3 uses cosine similarity (-1 to 1), ColGemma3 uses MaxSim (higher is better)
+ - **Heatmaps**: Enable to visualize ColGemma3 attention patterns (brighter = higher attention)
+ - **Refresh**: If you change documents, refresh the page to clear the index
+ """
+ )
+
  # Event handlers
  index_btn.click(
  fn=index_document,
  query_btn.click(
  fn=query_documents,
  inputs=[query_input, model_select, top_k_slider, heatmap_checkbox],
+ outputs=[bigemma_gallery, bigemma_results_text, colgemma_gallery, colgemma_results_text],
  )

+ # Enable queue for handling multiple requests
+ demo.queue(max_size=20)
+
+ if __name__ == "__main__":
+ demo.launch(debug=True)
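After the refactor, the pieces compose as `pdf_to_images` → `index_*_images` → `query_documents`. A minimal, hypothetical smoke test of that flow outside the Gradio UI, assuming the new `app.py` shown above is importable, a CUDA device is available, and `sample.pdf` is a placeholder path:

```python
# Hypothetical smoke test for the refactored helpers; names are taken from the new app.py above.
from app import doc_index, pdf_to_images, index_bigemma_images, query_documents

doc_index.images = pdf_to_images(["sample.pdf"])           # new list-of-paths signature
doc_index.bigemma_embeddings = index_bigemma_images(doc_index.images)

gallery, text, _, _ = query_documents(
    query="quarterly revenue table",
    model_choice="NetraEmbed (BiGemma3)",
    top_k=3,
)
print(text)  # markdown summary of the ranked pages, as rendered in the Results panel
```

Importing `app` builds the Blocks interface and enables the queue but does not launch it, since `demo.launch(debug=True)` is now guarded by `if __name__ == "__main__":`.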