|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image |
|
|
from transformers import pipeline |
|
|
import torch |
|
|
import tifffile |
|
|
import gradio as gr |
|
|
import os |
|
|
|
|
|
|
|
|
print("Step 1: Setting up the environment...")

# Device index handed to the pipeline below: 0 when CUDA is available,
# otherwise -1 (reported as "CPU" in the message that follows).
if torch.cuda.is_available():
    device = 0
else:
    device = -1

print(f" > Device selected: {'GPU' if device == 0 else 'CPU'}")


print("Step 2: Loading SAM Model...")

# Module-level mask-generation pipeline, built once at start-up and reused
# by every call to segment_image().
generator = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-huge",
    device=device,
)

print(" > SAM Model loaded successfully.")
|
|
|
|
|
|
|
|
_DOWNSCALE_FACTOR = 4  # each image side is divided by this before running SAM
_OVERLAY_ALPHA = 0.5  # opacity of the coloured segments in the preview figure
_COLOR_SEED = 42  # fixed seed so segment colours are reproducible across runs
_MASK_PATH = "labeled_mask.tif"
_PREVIEW_PATH = "segmented_overlay.png"


def _label_masks(masks, shape):
    """Merge binary masks into one label image (mask ``i`` gets label ``i + 1``)."""
    # uint16 keeps the TIFF compact; only widen when the label count would
    # otherwise wrap around silently.
    fits_uint16 = len(masks) <= np.iinfo(np.uint16).max
    labeled = np.zeros(shape, dtype=np.uint16 if fits_uint16 else np.uint32)
    for label, mask in enumerate(masks, start=1):
        # Where masks overlap, the later mask wins (same as plain assignment).
        labeled[np.asarray(mask, dtype=bool)] = label
    return labeled


def _colorize_labels(labeled):
    """Return an ``(h, w, 4)`` float RGBA overlay with one random colour per label."""
    labels = np.unique(labeled)
    labels = labels[labels != 0]  # 0 is background and stays fully transparent
    if labels.size == 0:
        return np.zeros(labeled.shape + (4,))

    # A private generator yields the same colour sequence as seeding the
    # global NumPy RNG with the same seed, without disturbing global state.
    rng = np.random.RandomState(_COLOR_SEED)

    # Colour lookup table indexed by label; a single fancy-index then paints
    # the whole image instead of re-scanning it once per label.
    lookup = np.zeros((int(labels[-1]) + 1, 4))
    lookup[labels, :3] = rng.rand(labels.size, 3)
    lookup[labels, 3] = _OVERLAY_ALPHA
    return lookup[labeled]


def segment_image(image):
    """Segment *image* with SAM and save the result as a labeled TIFF mask.

    The image is converted to RGB and downscaled by ``_DOWNSCALE_FACTOR`` on
    each side, then passed to the module-level ``generator`` pipeline. The
    returned binary masks are merged into a single integer label image
    (0 = background, ``i + 1`` = the i-th mask) at the downscaled resolution.

    Side effects: writes the label image to ``labeled_mask.tif`` and a
    side-by-side preview figure to ``segmented_overlay.png`` in the current
    working directory, and prints progress messages.

    Args:
        image: A PIL image (the Gradio input is configured with type="pil").

    Returns:
        The path of the written TIFF label mask.

    Raises:
        ValueError: If *image* is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError(
            "No image provided; upload an image before running segmentation."
        )

    print("Step 3: Starting image segmentation...")

    print(" > Resizing image...")
    raw_image = image.convert("RGB")
    original_size = raw_image.size
    # Clamp to 1 px so images smaller than the factor do not collapse to a
    # zero-sized target.
    resized_size = tuple(
        max(1, side // _DOWNSCALE_FACTOR) for side in original_size
    )
    raw_image = raw_image.resize(resized_size)
    print(f" > Original size: {original_size}, Resized size: {resized_size}")

    print(" > Running SAM segmentation...")
    outputs = generator(raw_image, points_per_batch=64)
    masks = outputs["masks"]
    print(f" > {len(masks)} masks generated.")

    print(" > Creating labeled mask...")
    if len(masks) > 0:
        h, w = np.shape(masks[0])
    else:
        # No segments found: fall back to the resized image's dimensions so
        # an all-background mask is still produced (PIL size is (w, h)).
        w, h = resized_size
    labeled_mask = _label_masks(masks, (h, w))
    print(" > Labeled mask created.")

    print(" > Generating overlay...")
    overlay = _colorize_labels(labeled_mask)
    print(" > Overlay generated.")

    output_path = _MASK_PATH
    print(" > Saving labeled mask as TIFF...")
    tifffile.imwrite(output_path, labeled_mask)
    print(f" > Mask saved to: {output_path}")

    print("Step 4: Plotting results...")
    figure, (original_axis, overlay_axis) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        original_axis.imshow(image)
        original_axis.set_title("Original Image")
        original_axis.axis("off")

        # Draw the downscaled image first, then the translucent segments on top.
        overlay_axis.imshow(raw_image)
        overlay_axis.imshow(overlay)
        overlay_axis.set_title("Segmented Overlay")
        overlay_axis.axis("off")

        figure.tight_layout()
        figure.savefig(_PREVIEW_PATH)
    finally:
        # Always release the figure, even if drawing or saving fails, so
        # repeated requests do not accumulate open figures.
        plt.close(figure)
    print(" > Results plotted.")

    return output_path
|
|
|
|
|
|
|
|
|
|
|
print("Step 5: Setting up Gradio interface...")

# UI components: the upload is handed to segment_image as a PIL image, and
# the path it returns is offered as a downloadable file.
image_input = gr.Image(type="pil")
mask_output = gr.File(label="Download Mask")

iface = gr.Interface(
    segment_image,
    image_input,
    mask_output,
    title="Image Segmentation with SAM",
    description="Upload an image to segment it and visualize the results.",
)


print("Step 6: Launching the interface...")

# NOTE(review): when run as a plain script, launch() presumably blocks until
# the server stops, so the message below may only appear at shutdown -- confirm.
iface.launch()

print(" > Interface launched successfully.")