Spaces: Running on L4
| import gradio as gr | |
| import random | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageOps | |
| import os | |
| import json | |
| import sys | |
| import multiprocessing | |
| from concurrent.futures import ProcessPoolExecutor | |
| import time | |
| # Assume MagicQuill and other dependencies are present as per user instruction | |
| from MagicQuill import folder_paths | |
| from MagicQuill.llava_new import LLaVAModel | |
| from huggingface_hub import snapshot_download | |
| # Imports for SAM (Only needed in worker process, but imported here for checking) | |
| from segment_anything import sam_model_registry, SamPredictor | |
# --- Model downloads and LLaVA (main process only) ---
# The SAM pool is created with the 'spawn' start method (see the __main__
# block). Spawn re-imports this script in every worker process as
# "__mp_main__", so without this guard each worker would re-run the downloads
# and load its own copy of LLaVA onto the GPU, even though workers never use it.
# multiprocessing.parent_process() is None only in the original process.
hf_token = os.environ.get("HF_TOKEN")
llavaModel = None

if multiprocessing.parent_process() is None:
    snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
    snapshot_download(
        repo_id="LiuZichen/MagicQuillV2-models",
        repo_type="model",
        local_dir="models_v2",
        token=hf_token,
    )
    print("Initializing LLaVAModel (Main Process)...")
    # LLaVA is too large to duplicate per process; it lives only here and is
    # called from Gradio's request threads via guess().
    llavaModel = LLaVAModel()
    print("LLaVAModel initialized.")
# --- Worker Process Logic for SAM ---
# Per-process SAM state. Every worker process has its own copy of this module
# global; init_worker_sam() fills it in and run_sam_inference() reads it.
worker_sam = None


def init_worker_sam(device='cuda', checkpoint_path='models_v2/sam/sam_vit_b_01ec64.pth'):
    """Load a standalone SAM ViT-B predictor into this process's ``worker_sam``.

    Used as the ``ProcessPoolExecutor`` initializer, so it runs once in each
    worker process. Load failures are logged, not raised: an exception from a
    pool initializer marks the whole pool as broken (every pending and future
    submission fails), whereas leaving ``worker_sam`` as ``None`` lets
    ``run_sam_inference`` retry the load on demand.

    Args:
        device: Torch device for the model. ``'cuda'`` falls back to ``'cpu'``
            when CUDA is not available in this process.
        checkpoint_path: SAM ViT-B checkpoint, downloaded into ``models_v2`` by
            the main process before the pool is started.
    """
    global worker_sam
    pid = os.getpid()
    print(f"Process {pid}: Initializing SAM model...")
    if device == 'cuda' and not torch.cuda.is_available():
        print(f"Process {pid}: CUDA unavailable, loading SAM on CPU.")
        device = 'cpu'
    try:
        sam = sam_model_registry['vit_b'](checkpoint=checkpoint_path)
        sam.to(device=device)
        worker_sam = {"predictor": SamPredictor(sam)}
        print(f"Process {pid}: SAM initialized.")
    except Exception as e:
        # Deliberate best effort; see the docstring for why this does not re-raise.
        print(f"Process {pid}: Failed to init SAM: {e}")
def _load_json_arg(value):
    """Decode ``value`` as JSON when it is a string; return anything else unchanged."""
    return json.loads(value) if isinstance(value, str) else value


def run_sam_inference(image_np, coordinates_positive, coordinates_negative, bboxes):
    """Run one SAM prediction inside a worker process.

    Args:
        image_np: Image passed to ``SamPredictor.set_image``. The Gradio input
            uses ``type="numpy"``; presumably HxWx3 uint8 RGB -- TODO confirm
            for direct API callers.
        coordinates_positive: JSON string or decoded list of ``{'x': .., 'y': ..}``
            dicts, each added as a label-1 point. Falsy (None, "", []) means none.
        coordinates_negative: Same format; each point gets label 0.
        bboxes: JSON string or list of boxes, each converted with ``list(box)``
            (presumably ``[x0, y0, x1, y1]`` -- TODO confirm). A string that is
            not valid JSON is logged and ignored.

    Returns:
        ``(mask, mask)``: the first predicted mask as a 2-D uint8 array with
        values 0/255, returned twice because ``segment()`` unpacks two values.

    Raises:
        RuntimeError: SAM could not be loaded in this worker process.
    """
    global worker_sam
    if worker_sam is None:
        # The pool initializer only logs failures; retry the load once here.
        init_worker_sam()
    if worker_sam is None:
        raise RuntimeError(f"SAM model is not available in worker process {os.getpid()}")
    predictor = worker_sam["predictor"]
    predictor.set_image(image_np)

    # Point prompts: positives first (label 1), then negatives (label 0).
    points = []
    labels = []
    for raw_coords, label in ((coordinates_positive, 1), (coordinates_negative, 0)):
        if not raw_coords:
            continue
        for p in _load_json_arg(raw_coords):
            points.append([p['x'], p['y']])
            labels.append(label)

    # Box prompts.
    input_box = None
    if bboxes:
        if isinstance(bboxes, str):
            try:
                bboxes = json.loads(bboxes)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                print(f"Process {os.getpid()}: ignoring invalid bboxes JSON: {bboxes!r}")
                bboxes = None
        if isinstance(bboxes, list) and bboxes:
            # NOTE(review): SamPredictor.predict documents `box` as a single
            # length-4 XYXY array; sending more than one box is likely to fail
            # inside SAM -- confirm whether callers ever do.
            input_box = np.array([list(box) for box in bboxes])

    if points:
        point_coords = np.array(points)
        point_labels = np.array(labels)
    else:
        point_coords = None
        point_labels = None

    masks, _scores, _logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        box=input_box,
        multimask_output=False,
    )

    # Binarize to uint8 {0, 255} for cheap transport back to the main process:
    # bool masks are used directly, anything else is thresholded at > 0.
    mask = masks[0]
    mask = (mask if mask.dtype == bool else mask > 0).astype(np.uint8) * 255
    return mask, mask
# --- Main Process Helpers ---
# Handle to the SAM ProcessPoolExecutor. It stays None until the __main__ block
# creates the pool; segment() checks for that before submitting work.
sam_pool = None


def numpy_to_tensor(numpy_array):
    """Convert an image array to a float32 tensor scaled by 1/255 with a new leading axis of size 1."""
    as_float = torch.from_numpy(numpy_array).float()
    with_batch_axis = as_float[None, ...]
    return with_batch_axis / 255.
def guess(original_image, add_color_image, add_edge_mask):
    """Ask LLaVA what the user drew and return its answers as one string.

    Runs in the main process (Gradio calls it from a request thread). The three
    images are converted with numpy_to_tensor() and passed, in order, to
    llavaModel.process(), which returns ``(description, ans1, ans2)``. The
    description is ignored; the non-empty answers are joined with ", ", and the
    result is "" when both are empty.
    """
    tensors = [numpy_to_tensor(img) for img in (original_image, add_color_image, add_edge_mask)]
    _description, first_answer, second_answer = llavaModel.process(*tensors)
    return ", ".join(answer for answer in (first_answer, second_answer) if answer)
def get_mask_bbox(mask_np):
    """Return the inclusive bounding box of a mask's non-zero pixels.

    Accepts an ``[H, W]`` mask or a ``[1, H, W]`` mask, in which case the first
    plane is used. Returns ``(x_min, y_min, x_max, y_max)`` as Python ints, or
    ``None`` when the mask has no non-zero pixel.
    """
    plane = mask_np[0] if mask_np.ndim == 3 else mask_np
    hit_rows = np.flatnonzero(plane.any(axis=1))
    hit_cols = np.flatnonzero(plane.any(axis=0))
    if hit_rows.size == 0 or hit_cols.size == 0:
        return None
    return int(hit_cols[0]), int(hit_rows[0]), int(hit_cols[-1]), int(hit_rows[-1])
# Seconds segment() waits for a SAM worker before giving up on a request.
SAM_RESULT_TIMEOUT_S = 60


def segment(image, coordinates_positive, coordinates_negative, bboxes):
    """Segment ``image`` with SAM in the worker process pool.

    Args:
        image: numpy image from the Gradio input (``type="numpy"``), or None
            when no image was supplied.
        coordinates_positive: Positive-point JSON, forwarded to run_sam_inference.
        coordinates_negative: Negative-point JSON, forwarded to run_sam_inference.
        bboxes: Box JSON, forwarded to run_sam_inference.

    Returns:
        ``(PIL image of the 0/255 mask, JSON string)``. On success the JSON is the
        mask's inclusive bounding box ``{'startX', 'startY', 'endX', 'endY'}``
        (all zeros for an empty mask). When the request cannot be served, the
        image is None and the JSON is ``{'error': ...}``.
    """
    from concurrent.futures import TimeoutError as FutureTimeoutError

    if image is None:
        return None, json.dumps({'error': 'No input image provided'})

    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)
    if sam_pool is None:
        return None, json.dumps({'error': 'SAM pool not initialized'})

    future = sam_pool.submit(run_sam_inference, image, coordinates_positive, coordinates_negative, bboxes)
    try:
        image_with_alpha_np, mask_np = future.result(timeout=SAM_RESULT_TIMEOUT_S)
    except FutureTimeoutError:
        # Drop the job if it has not started yet, so an abandoned request does
        # not tie up a worker later (a job that is already running cannot be cancelled).
        future.cancel()
        return None, json.dumps({'error': f'SAM inference timed out after {SAM_RESULT_TIMEOUT_S}s'})

    res_pil = Image.fromarray(image_with_alpha_np)

    mask_bbox = get_mask_bbox(mask_np)
    if mask_bbox:
        x_min, y_min, x_max, y_max = mask_bbox
        seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
    else:
        seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
    return res_pil, json.dumps(seg_bbox)
# --- Gradio UI ---
# Two tabs, each also exposed as a named API endpoint ("guess_prompt" and
# "segment") for programmatic clients.
# NOTE(review): this block runs at import time. With the 'spawn' start method
# set in __main__, it is presumably also rebuilt in every SAM worker process --
# confirm whether it should move under a main-process guard.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
    with gr.Tab("Draw & Guess"):
        with gr.Row():
            dg_input_img = gr.Image(label="Original Image")
            dg_color_img = gr.Image(label="Colored Image")
            # image_mode="L": the edge mask is delivered as a single-channel image.
            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
        dg_output = gr.Textbox(label="Prediction Output")
        dg_btn = gr.Button("Guess")
        # One LLaVA call at a time: the single model instance lives in this process.
        dg_btn.click(
            fn=guess,
            inputs=[dg_input_img, dg_color_img, dg_edge_img],
            outputs=dg_output,
            api_name="guess_prompt",
            concurrency_limit=1
        )
    with gr.Tab("SAM Segmentation"):
        with gr.Row():
            sam_input_img = gr.Image(label="Input Image", type="numpy")
            # Prompt inputs are raw JSON strings, parsed in run_sam_inference.
            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
            sam_bboxes = gr.Textbox(label="BBoxes JSON")
        with gr.Row():
            sam_output_img = gr.Image(label="Segmented Image", format="png")
            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
        sam_btn = gr.Button("Segment")
        # concurrency_limit=5 matches NUM_SAM_WORKERS = 5 in the __main__ block;
        # keep the two in step if either changes.
        sam_btn.click(
            fn=segment,
            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
            outputs=[sam_output_img, sam_output_bbox],
            api_name="segment",
            concurrency_limit=5
        )
if __name__ == "__main__":
    # CUDA does not support the 'fork' start method, so SAM workers are spawned.
    # Each spawned worker re-imports this script and loads its own SAM model.
    multiprocessing.set_start_method('spawn', force=True)

    # Each worker holds a full SAM ViT-B, so size this to the available GPU
    # memory. Override with the NUM_SAM_WORKERS environment variable; the
    # default of 5 matches concurrency_limit=5 on the "segment" endpoint.
    NUM_SAM_WORKERS = int(os.environ.get("NUM_SAM_WORKERS", "5"))
    print(f"Starting {NUM_SAM_WORKERS} SAM worker processes...")
    # Rebinds the module-level sam_pool that segment() reads.
    sam_pool = ProcessPoolExecutor(max_workers=NUM_SAM_WORKERS, initializer=init_worker_sam)

    # Launch Gradio
    app.queue(max_size=40).launch(max_threads=5)