LiuZichen's picture
Update app.py
0f883aa verified
import gradio as gr
import random
import torch
import numpy as np
from PIL import Image, ImageOps
import os
import json
import sys
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import time
# Assume MagicQuill and other dependencies are present as per user instruction
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download
# Imports for SAM (Only needed in worker process, but imported here for checking)
from segment_anything import sam_model_registry, SamPredictor
# --- Model download & main-process LLaVA initialisation ---
# __main__ switches multiprocessing to the 'spawn' start method, and a spawned
# worker re-imports this script (as `__mp_main__`), executing every
# module-level statement again. Without a guard, each SAM worker would re-run
# the downloads and build its own full LLaVAModel, wasting startup time and
# GPU memory. This work is therefore restricted to the parent process; the
# workers only need the SAM checkpoint, which the parent downloads before the
# pool is created in __main__.
hf_token = os.environ.get("HF_TOKEN")
_IS_MAIN_PROCESS = multiprocessing.parent_process() is None

if _IS_MAIN_PROCESS:
    # Download models (main process only)
    snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
    snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)
    # --- Global Models for Main Process ---
    print("Initializing LLaVAModel (Main Process)...")
    # A single shared instance, used by guess() on the main process's request threads.
    llavaModel = LLaVAModel()
    print("LLaVAModel initialized.")
else:
    # SAM worker processes never call guess(); keep the name defined but unloaded.
    llavaModel = None
# --- Worker Process Logic for SAM ---
# Per-process SAM state: {"predictor": SamPredictor} once init_worker_sam()
# succeeds in this process, otherwise None.
worker_sam = None


def init_worker_sam(device='cuda', checkpoint_path='models_v2/sam/sam_vit_b_01ec64.pth', model_type='vit_b'):
    """Load a standalone SAM predictor into this process's ``worker_sam``.

    Intended as the ``ProcessPoolExecutor`` initializer, so it runs once in
    every newly started worker process. Failures are reported but NOT raised:
    an exception escaping a pool initializer makes the whole pool unusable
    (pending jobs fail with ``BrokenProcessPool``). Leaving ``worker_sam`` as
    None instead lets the failure be handled per request.

    Args:
        device: torch device the SAM model is moved to.
        checkpoint_path: path of the SAM checkpoint file.
        model_type: key into ``sam_model_registry``; must match the checkpoint.

    Returns:
        True if the predictor was loaded, False otherwise.
    """
    import traceback

    global worker_sam
    pid = os.getpid()
    print(f"Process {pid}: Initializing SAM model...")
    try:
        sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        sam.to(device=device)
        predictor = SamPredictor(sam)
    except Exception as e:
        # Best-effort by design (see docstring), but keep the traceback so the
        # root cause (missing checkpoint, CUDA OOM, ...) is visible in the logs.
        print(f"Process {pid}: Failed to init SAM: {e}")
        traceback.print_exc()
        worker_sam = None
        return False

    worker_sam = {
        "predictor": predictor
    }
    print(f"Process {pid}: SAM initialized.")
    return True
def _parse_json_arg(value):
    """Return *value* decoded from JSON when it is a string, else *value* unchanged."""
    return json.loads(value) if isinstance(value, str) else value


def _points_from(coordinates, label):
    """Convert click coordinates into SAM point prompts.

    *coordinates* is a list of ``{'x': ..., 'y': ...}`` dicts or its JSON text.
    Returns ``([[x, y], ...], [label, ...])``. Both lists are empty when
    *coordinates* is empty/None or decodes to an empty value (``"[]"``, ``"null"``).
    Malformed JSON raises ``json.JSONDecodeError``.
    """
    if not coordinates:
        return [], []
    coords = _parse_json_arg(coordinates) or []
    xy = [[p['x'], p['y']] for p in coords]
    return xy, [label] * len(xy)


def _boxes_from(bboxes):
    """Build the SAM box prompt from a list of boxes (or its JSON text).

    Each box is turned into ``list(box)``, so it is expected to be a 4-element
    ``[x1, y1, x2, y2]`` sequence (SAM's XYXY format) - TODO confirm against the client.
    Unparsable JSON or a non-list payload yields no box prompt (None).
    """
    if not bboxes:
        return None
    if isinstance(bboxes, str):
        try:
            bboxes = json.loads(bboxes)
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            # Malformed boxes are skipped, but no longer silently.
            print(f"Process {os.getpid()}: ignoring unparsable bboxes {bboxes!r}: {e}")
            return None
    if not isinstance(bboxes, list):
        return None
    box_list = [list(box) for box in bboxes]
    # NOTE(review): SamPredictor.predict documents a single length-4 box;
    # several boxes are passed through as an (N, 4) array and may be rejected.
    return np.array(box_list) if box_list else None


def _to_rgb(image_np):
    """Return *image_np* as an H x W x 3 array, the layout SamPredictor.set_image expects.

    Grayscale input (H x W or H x W x 1) is replicated to three channels and an
    alpha channel (H x W x 4) is dropped; any other shape is returned unchanged.
    """
    image_np = np.asarray(image_np)
    if image_np.ndim == 2:
        return np.stack([image_np] * 3, axis=-1)
    if image_np.ndim == 3 and image_np.shape[2] == 1:
        return np.repeat(image_np, 3, axis=2)
    if image_np.ndim == 3 and image_np.shape[2] == 4:
        return np.ascontiguousarray(image_np[:, :, :3])
    return image_np


def run_sam_inference(image_np, coordinates_positive, coordinates_negative, bboxes):
    """Run one SAM prediction inside a worker process and return the mask.

    Args:
        image_np: input image array (RGB; grayscale / RGBA are coerced to 3 channels).
        coordinates_positive: foreground clicks, a list of {'x', 'y'} dicts or its JSON text.
        coordinates_negative: background clicks, same format.
        bboxes: list of boxes or its JSON text (see _boxes_from).

    Returns:
        ``(mask, mask)``: the single predicted mask as a uint8 array with values
        {0, 255}. It is returned twice because segment() unpacks the result as
        ``(image_with_alpha_np, mask_np)``.

    Raises:
        RuntimeError: if no SAM predictor could be loaded in this process.
        json.JSONDecodeError: if a coordinates string is not valid JSON.
    """
    if worker_sam is None:
        # The pool initializer normally loaded SAM already; retry once if it failed.
        init_worker_sam()
    if worker_sam is None:
        raise RuntimeError(
            f"Process {os.getpid()}: SAM predictor unavailable (initialisation failed, see worker log)"
        )
    predictor = worker_sam["predictor"]

    # Parse prompts before set_image so bad input fails before the costly image embedding.
    pos_xy, pos_labels = _points_from(coordinates_positive, 1)
    neg_xy, neg_labels = _points_from(coordinates_negative, 0)
    input_box = _boxes_from(bboxes)
    if pos_xy or neg_xy:
        input_point = np.array(pos_xy + neg_xy)
        input_label = np.array(pos_labels + neg_labels)
    else:
        input_point = input_label = None

    predictor.set_image(_to_rgb(image_np))
    masks, _scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )
    # multimask_output=False yields one mask; binarise to uint8 {0, 255} for transport
    # (`> 0` works for both boolean and numeric masks).
    mask_np = (masks[0] > 0).astype(np.uint8) * 255
    return mask_np, mask_np
# --- Main Process Helpers ---
# Process pool that executes run_sam_inference. It is only created inside the
# `if __name__ == "__main__":` section, so it stays None in spawned worker
# processes and whenever this file is imported instead of run as a script.
sam_pool: "ProcessPoolExecutor | None" = None
def numpy_to_tensor(numpy_array):
    """Wrap an image array as a float tensor scaled by 1/255 with a new leading axis of size 1.

    Assumes 8-bit pixel values (0-255) as delivered by gr.Image - TODO confirm for all callers.
    """
    scaled = torch.from_numpy(numpy_array).float() / 255.
    return scaled.unsqueeze(0)
def guess(original_image, add_color_image, add_edge_mask):
    """Run LLaVA on the three input images and return its answers as one string.

    Each image is converted with numpy_to_tensor and passed to
    ``llavaModel.process``, which returns ``(description, answer1, answer2)``.
    The description is discarded. The non-empty answers are joined with ", ",
    and the result is "" when both are empty.
    """
    # Executes on a Gradio request thread in the main process, where llavaModel lives.
    tensors = [numpy_to_tensor(img) for img in (original_image, add_color_image, add_edge_mask)]
    _description, first, second = llavaModel.process(*tensors)
    return ", ".join(answer for answer in (first, second) if answer)
def get_mask_bbox(mask_np):
    """Return the tight ``(x_min, y_min, x_max, y_max)`` box around non-zero mask pixels.

    Accepts a 2-D ``[H, W]`` mask, or a 3-D mask of which only index 0 along the
    first axis is used (e.g. ``[1, H, W]``). Bounds are inclusive pixel indices
    as plain ints. Returns None when the mask has no non-zero pixel.
    """
    plane = mask_np[0] if mask_np.ndim == 3 else mask_np
    ys, xs = np.nonzero(plane)
    if ys.size == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
def segment(image, coordinates_positive, coordinates_negative, bboxes):
    """Gradio handler: segment *image* with SAM in the worker pool.

    Args:
        image: uint8 numpy image from the input component, or None when empty.
        coordinates_positive, coordinates_negative, bboxes: prompt payloads,
            forwarded unchanged to run_sam_inference in a worker process.

    Returns:
        On success, ``(PIL.Image, json_str)``. The image is the mask array
        returned by the worker, and ``json_str`` encodes
        ``{'startX', 'startY', 'endX', 'endY'}`` (all zeros for an empty mask).
        On failure, ``(None, json_str)`` where the JSON object has an ``'error'`` key.
    """
    from concurrent.futures import TimeoutError as FutureTimeoutError

    timeout_s = 60

    if image is None:
        return None, json.dumps({'error': 'No input image provided'})

    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)

    if sam_pool is None:
        return None, json.dumps({'error': 'SAM pool not initialized'})

    future = sam_pool.submit(run_sam_inference, image, coordinates_positive, coordinates_negative, bboxes)
    try:
        image_with_alpha_np, mask_np = future.result(timeout=timeout_s)
    except FutureTimeoutError:
        # cancel() only succeeds while the task is still queued; a task that is
        # already running keeps its worker busy until it finishes.
        future.cancel()
        return None, json.dumps({'error': f'SAM inference timed out after {timeout_s}s'})
    except Exception as e:
        # Request boundary: report worker-side failures (bad prompts, missing
        # predictor, broken pool) with the same error payload as above.
        print(f"SAM inference failed: {e!r}")
        return None, json.dumps({'error': f'SAM inference failed: {e}'})

    # Convert back to PIL for Gradio
    res_pil = Image.fromarray(image_with_alpha_np)

    mask_bbox = get_mask_bbox(mask_np)
    if mask_bbox is not None:
        x_min, y_min, x_max, y_max = mask_bbox
        seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
    else:
        seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
    return res_pil, json.dumps(seg_bbox)
# --- Gradio UI ---
# Two tabs, each wiring one handler to a named API endpoint. Components are
# placed in the order they are created inside each `with` context, so the
# statement order and nesting below define the page layout.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
    with gr.Tab("Draw & Guess"):
        with gr.Row():
            dg_input_img = gr.Image(label="Original Image")
            dg_color_img = gr.Image(label="Colored Image")
            # image_mode="L": the edge mask is loaded as a single-channel (grayscale) image.
            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
        dg_output = gr.Textbox(label="Prediction Output")
        dg_btn = gr.Button("Guess")
        # Exposed as the "guess_prompt" API endpoint. concurrency_limit=1
        # serialises calls - presumably because guess() shares the single
        # main-process llavaModel; confirm before raising it.
        dg_btn.click(
            fn=guess,
            inputs=[dg_input_img, dg_color_img, dg_edge_img],
            outputs=dg_output,
            api_name="guess_prompt",
            concurrency_limit=1
        )
    with gr.Tab("SAM Segmentation"):
        with gr.Row():
            sam_input_img = gr.Image(label="Input Image", type="numpy")
            # Prompts arrive as JSON text and are decoded in the worker process.
            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
            sam_bboxes = gr.Textbox(label="BBoxes JSON")
        with gr.Row():
            # format="png" makes the returned mask image be served as PNG.
            sam_output_img = gr.Image(label="Segmented Image", format="png")
            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
        sam_btn = gr.Button("Segment")
        # Exposed as the "segment" API endpoint. concurrency_limit=5 matches the
        # default number of SAM worker processes started in __main__ (5); keep
        # the two in sync so concurrent requests do not just queue in the pool.
        sam_btn.click(
            fn=segment,
            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
            outputs=[sam_output_img, sam_output_bbox],
            api_name="segment",
            concurrency_limit=5
        )
if __name__ == "__main__":
    # 'spawn' starts each worker in a fresh interpreter so it can initialise CUDA
    # itself (CUDA cannot be re-initialised in a fork()ed child).
    multiprocessing.set_start_method('spawn', force=True)

    # Each worker loads its own SAM-B onto the GPU, so size the pool to the
    # available GPU memory. Override with the NUM_SAM_WORKERS environment
    # variable; the default of 5 matches the segment endpoint's concurrency_limit.
    NUM_SAM_WORKERS = int(os.environ.get("NUM_SAM_WORKERS", "5"))
    print(f"Starting {NUM_SAM_WORKERS} SAM worker processes...")
    sam_pool = ProcessPoolExecutor(max_workers=NUM_SAM_WORKERS, initializer=init_worker_sam)

    try:
        # launch() blocks the main thread until the server is stopped.
        app.queue(max_size=40).launch(max_threads=5)
    finally:
        # Stop the worker processes (releasing their GPU memory) when the server exits.
        sam_pool.shutdown()