import gradio as gr
import random
import torch
import numpy as np
from PIL import Image, ImageOps
import os
import json
import sys
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import time
# MagicQuill and its other dependencies are assumed to be installed.
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download
# Imports for SAM (only needed in the worker processes, but imported here so a missing dependency fails fast)
from segment_anything import sam_model_registry, SamPredictor
# Download model weights (snapshot_download skips files already present locally)
hf_token = os.environ.get("HF_TOKEN")
snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)
# --- Global Models for Main Process ---
# NOTE: with the 'spawn' start method, worker processes re-import this module,
# so guard the heavy LLaVA load to keep it out of the SAM workers.
if multiprocessing.current_process().name == "MainProcess":
    print("Initializing LLaVAModel (Main Process)...")
    # LLaVA is stateless/thread-safe enough to share, and too large to
    # duplicate per process, so we keep a single instance in the main
    # process (served from Gradio's worker threads).
    llavaModel = LLaVAModel()
    print("LLaVAModel initialized.")
else:
    llavaModel = None
# --- Worker Process Logic for SAM ---
# Global variable for the worker process to hold its own SAM instance
worker_sam = None
def init_worker_sam(device='cuda'):
    """
    Called once when a new worker process starts.
    Initializes a standalone SAM model owned by that process.
    """
    global worker_sam
    print(f"Process {os.getpid()}: Initializing SAM model...")
    # Ideally the SAM loading logic would live in a separate module so it
    # pickles/imports cleanly; for this single-file script we inline it here.
    checkpoint_path = 'models_v2/sam/sam_vit_b_01ec64.pth'
    # Load model
    try:
        sam = sam_model_registry['vit_b'](checkpoint=checkpoint_path)
        sam.to(device=device)
        predictor = SamPredictor(sam)
        worker_sam = {
            "predictor": predictor
        }
        print(f"Process {os.getpid()}: SAM initialized.")
    except Exception as e:
        print(f"Process {os.getpid()}: Failed to init SAM: {e}")
def run_sam_inference(image_np, coordinates_positive, coordinates_negative, bboxes):
    """
    The actual inference function running inside the worker process.
    """
    global worker_sam
    if worker_sam is None:
        # Fallback if init didn't run or failed
        # (though the ProcessPool initializer should handle it).
        init_worker_sam()
    if worker_sam is None:
        raise RuntimeError(f"Process {os.getpid()}: SAM is unavailable")
    predictor = worker_sam["predictor"]
    # Set image
    predictor.set_image(image_np)
    input_point = []
    input_label = []
    # Process positive/negative click points
    if coordinates_positive:
        coords = json.loads(coordinates_positive) if isinstance(coordinates_positive, str) else coordinates_positive
        for p in coords:
            input_point.append([p['x'], p['y']])
            input_label.append(1)
    if coordinates_negative:
        coords = json.loads(coordinates_negative) if isinstance(coordinates_negative, str) else coordinates_negative
        for p in coords:
            input_point.append([p['x'], p['y']])
            input_label.append(0)
    # Process bounding boxes
    input_box = None
    if bboxes:
        if isinstance(bboxes, str):
            try:
                bboxes = json.loads(bboxes)
            except json.JSONDecodeError:
                pass
        box_list = []
        if isinstance(bboxes, list):
            for box in bboxes:
                box_list.append(list(box))
        if len(box_list) > 0:
            input_box = np.array(box_list)
    if len(input_point) > 0:
        input_point = np.array(input_point)
        input_label = np.array(input_label)
    else:
        input_point = None
        input_label = None
    # Predict (single best mask)
    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )
    mask_np = masks[0]
    # Post-processing: simply convert the mask to uint8 [0, 255] for transport
    if mask_np.dtype == bool:
        mask_np = mask_np.astype(np.uint8) * 255
    else:
        mask_np = (mask_np > 0).astype(np.uint8) * 255
    # segment() expects (image_with_alpha_np, mask_np), so we return the mask
    # twice to satisfy that unpacking.
    return mask_np, mask_np
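# Illustrative worker-side call (hypothetical clicks/box, xyxy pixel coords):
#   mask, _ = run_sam_inference(image_np,
#                               '[{"x": 120, "y": 80}]',  # one positive click
#                               None,                     # no negative clicks
#                               '[[50, 40, 300, 260]]')   # one bounding box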
# --- Main Process Helpers ---
# The SAM process pool is created in the __main__ block below.
sam_pool = None
def numpy_to_tensor(numpy_array):
    tensor = torch.from_numpy(numpy_array).float().unsqueeze(0) / 255.
    return tensor
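# numpy_to_tensor example: an (H, W, 3) uint8 image becomes a (1, H, W, 3)
# float tensor scaled to [0, 1].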
def guess(original_image, add_color_image, add_edge_mask):
    # LLaVA inference runs in the main process (threaded)
    original_image_tensor = numpy_to_tensor(original_image)
    add_color_image_tensor = numpy_to_tensor(add_color_image)
    add_edge_mask_tensor = numpy_to_tensor(add_edge_mask)
    description, ans1, ans2 = llavaModel.process(original_image_tensor, add_color_image_tensor, add_edge_mask_tensor)
    ans_list = []
    if ans1:
        ans_list.append(ans1)
    if ans2:
        ans_list.append(ans2)
    return ", ".join(ans_list)
def get_mask_bbox(mask_np):
    # mask_np: [1, H, W] or [H, W]
    if mask_np.ndim == 3:
        mask_np = mask_np[0]
    rows = np.any(mask_np, axis=1)
    cols = np.any(mask_np, axis=0)
    if not np.any(rows) or not np.any(cols):
        return None
    y_min, y_max = np.where(rows)[0][[0, -1]]
    x_min, x_max = np.where(cols)[0][[0, -1]]
    return int(x_min), int(y_min), int(x_max), int(y_max)
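# get_mask_bbox example: a mask whose foreground spans rows 10..20 and
# cols 5..30 yields (5, 10, 30, 20), i.e. (x_min, y_min, x_max, y_max).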
def segment(image, coordinates_positive, coordinates_negative, bboxes):
    # image: numpy array (uint8)
    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)
    if sam_pool is None:
        return None, json.dumps({'error': 'SAM pool not initialized'})
    # Submit the task to the SAM process pool and wait for the result
    future = sam_pool.submit(run_sam_inference, image, coordinates_positive, coordinates_negative, bboxes)
    image_with_alpha_np, mask_np = future.result(timeout=60)  # 60s timeout
    # Convert back to PIL for Gradio
    # (run_sam_inference currently returns the mask in both slots)
    res_pil = Image.fromarray(image_with_alpha_np)
    # Calculate the mask bounding box
    mask_bbox = get_mask_bbox(mask_np)
    if mask_bbox:
        x_min, y_min, x_max, y_max = mask_bbox
        seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
    else:
        seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
    return res_pil, json.dumps(seg_bbox)
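# Sketch of a remote call via gradio_client (client-side; the URL and file
# name are assumptions, the endpoint name comes from api_name="segment" below):
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   img, bbox_json = client.predict(
#       handle_file("input.png"),      # image
#       '[{"x": 120, "y": 80}]',       # positive clicks
#       "",                            # negative clicks
#       "",                            # bboxes
#       api_name="/segment",
#   )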
# --- Gradio UI ---
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
    with gr.Tab("Draw & Guess"):
        with gr.Row():
            dg_input_img = gr.Image(label="Original Image")
            dg_color_img = gr.Image(label="Colored Image")
            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
        dg_output = gr.Textbox(label="Prediction Output")
        dg_btn = gr.Button("Guess")
        dg_btn.click(
            fn=guess,
            inputs=[dg_input_img, dg_color_img, dg_edge_img],
            outputs=dg_output,
            api_name="guess_prompt",
            concurrency_limit=1
        )
    with gr.Tab("SAM Segmentation"):
        with gr.Row():
            sam_input_img = gr.Image(label="Input Image", type="numpy")
            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
            sam_bboxes = gr.Textbox(label="BBoxes JSON")
        with gr.Row():
            sam_output_img = gr.Image(label="Segmented Image", format="png")
            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
        sam_btn = gr.Button("Segment")
        sam_btn.click(
            fn=segment,
            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
            outputs=[sam_output_img, sam_output_bbox],
            api_name="segment",
            concurrency_limit=5
        )
if __name__ == "__main__":
    # Use the 'spawn' start method for CUDA compatibility
    multiprocessing.set_start_method('spawn', force=True)
    # Initialize the SAM pool; adjust max_workers to GPU memory
    # (e.g., 2-4 workers for SAM ViT-B)
    NUM_SAM_WORKERS = 5
    print(f"Starting {NUM_SAM_WORKERS} SAM worker processes...")
    sam_pool = ProcessPoolExecutor(max_workers=NUM_SAM_WORKERS, initializer=init_worker_sam)
    # Launch Gradio
    app.queue(max_size=40).launch(max_threads=5)
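    # Note: max_threads=5 matches NUM_SAM_WORKERS and the segment endpoint's
    # concurrency_limit, so at most five SAM requests run concurrently.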