File size: 8,945 Bytes
0e84795
 
 
 
 
191dbfa
 
 
 
 
 
 
 
0e84795
 
 
191dbfa
 
 
 
 
 
0e84795
191dbfa
0e84795
191dbfa
 
 
0e84795
191dbfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f59fe24
191dbfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e84795
 
 
 
 
191dbfa
 
 
 
 
 
 
 
0e84795
 
 
 
 
191dbfa
0e84795
 
191dbfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
644a908
 
191dbfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
644a908
0f883aa
191dbfa
 
 
 
 
 
 
 
0f883aa
191dbfa
 
 
 
0f883aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import gradio as gr
import random
import torch
import numpy as np
from PIL import Image, ImageOps
import os
import json
import sys
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import time

# Assume MagicQuill and other dependencies are present as per user instruction
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download

# Imports for SAM (Only needed in worker process, but imported here for checking)
from segment_anything import sam_model_registry, SamPredictor

# Only the main process downloads weights and hosts LLaVA.
# The __main__ block selects the 'spawn' start method, and a spawned worker
# re-imports this module; without this guard every SAM worker would repeat the
# downloads and load its own copy of LLaVA.
_IS_MAIN_PROCESS = multiprocessing.current_process().name == "MainProcess"

# Optional Hugging Face token from the environment; only passed to the V2 repo download.
hf_token = os.environ.get("HF_TOKEN")

if _IS_MAIN_PROCESS:
    # Download models (Main process does this once)
    snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
    snapshot_download(repo_id="LiuZichen/MagicQuillV2-models", repo_type="model", local_dir="models_v2", token=hf_token)

    # --- Global Models for Main Process ---
    print("Initializing LLaVAModel (Main Process)...")
    # A single LLaVA instance is shared by all guess() calls in the main process.
    llavaModel = LLaVAModel()
    print("LLaVAModel initialized.")
else:
    # SAM workers only execute run_sam_inference(), so they do not need LLaVA.
    llavaModel = None

# --- Worker Process Logic for SAM ---
# Per-process holder for the SAM predictor; populated by init_worker_sam().
worker_sam = None

def init_worker_sam(device='cuda'):
    """
    Load a SAM ViT-B predictor into this process's ``worker_sam`` global.

    Passed as the ``initializer`` of the SAM ``ProcessPoolExecutor``, and also
    called lazily by ``run_sam_inference`` when ``worker_sam`` is still None.

    Args:
        device: Torch device to place the model on. A CUDA device is replaced
            by ``'cpu'`` when CUDA is not available in this process.

    Returns:
        None. On success ``worker_sam`` becomes ``{"predictor": SamPredictor}``.
        On failure the error is printed and ``worker_sam`` is left unchanged;
        the exception is deliberately not re-raised, because an initializer
        that raises marks the whole process pool as broken.
    """
    global worker_sam
    print(f"Process {os.getpid()}: Initializing SAM model...")

    checkpoint_path = 'models_v2/sam/sam_vit_b_01ec64.pth'

    try:
        # Fall back to CPU instead of failing every worker on a CUDA-less host.
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            print(f"Process {os.getpid()}: CUDA not available, loading SAM on CPU.")
            device = 'cpu'

        sam = sam_model_registry['vit_b'](checkpoint=checkpoint_path)
        sam.to(device=device)

        worker_sam = {
            "predictor": SamPredictor(sam)
        }
        print(f"Process {os.getpid()}: SAM initialized.")
    except Exception as e:
        print(f"Process {os.getpid()}: Failed to init SAM: {e}")

def _collect_point_prompts(coordinates_positive, coordinates_negative):
    """Build SAM point prompts; returns (points, labels) arrays or (None, None) if there are no points."""
    points = []
    labels = []
    # Positive clicks are labelled 1, negative clicks 0.
    for raw, label in ((coordinates_positive, 1), (coordinates_negative, 0)):
        if not raw:
            continue
        coords = json.loads(raw) if isinstance(raw, str) else raw
        for p in coords:
            points.append([p['x'], p['y']])
            labels.append(label)

    if not points:
        return None, None
    return np.array(points), np.array(labels)


def _collect_box_prompts(bboxes):
    """Build the SAM box prompt array from a list (or JSON string) of boxes; None if unusable."""
    if not bboxes:
        return None

    if isinstance(bboxes, str):
        try:
            bboxes = json.loads(bboxes)
        except ValueError:
            # Best effort: a malformed box string means "no box prompt", not a failed request.
            print(f"Process {os.getpid()}: Ignoring malformed bboxes JSON.")
            return None

    if not isinstance(bboxes, list):
        return None

    box_list = [list(box) for box in bboxes]
    return np.array(box_list) if box_list else None


def run_sam_inference(image_np, coordinates_positive, coordinates_negative, bboxes):
    """
    Run one SAM prediction inside a worker process.

    Args:
        image_np: Image array handed to ``predictor.set_image``.
        coordinates_positive: List (or JSON string) of ``{'x': ..., 'y': ...}``
            points to include; may be empty/None.
        coordinates_negative: Same format, points to exclude; may be empty/None.
        bboxes: List (or JSON string) of box coordinate sequences; may be
            empty/None. Malformed JSON is ignored.

    Returns:
        ``(mask_np, mask_np)``: the same uint8 mask (values 0 or 255) twice,
        because ``segment()`` unpacks two values from the result.

    Raises:
        RuntimeError: If the SAM model could not be initialized in this process.
        ValueError: If a coordinates string is not valid JSON.
    """
    if worker_sam is None:
        # Fallback if the pool initializer did not run or failed.
        init_worker_sam()
    if worker_sam is None:
        # init_worker_sam swallows its errors, so check again and fail clearly.
        raise RuntimeError("SAM model is not available in this worker process")

    predictor = worker_sam["predictor"]
    predictor.set_image(image_np)

    input_point, input_label = _collect_point_prompts(coordinates_positive, coordinates_negative)
    input_box = _collect_box_prompts(bboxes)

    masks, _scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )

    # Convert the first mask to uint8 [0, 255] for transport back to the main
    # process; `> 0` handles both boolean and numeric masks.
    mask_np = (masks[0] > 0).astype(np.uint8) * 255

    return mask_np, mask_np


# --- Main Process Helpers ---

# Process pool used by segment() to run SAM inference.
# It stays None at import time and is created in the __main__ block;
# segment() returns an error payload while it is still None.
sam_pool = None

def numpy_to_tensor(numpy_array):
    """Convert a NumPy array to a float tensor with a leading batch axis, scaled by 1/255."""
    as_float = torch.from_numpy(numpy_array).float()
    batched = as_float.unsqueeze(0)
    return batched / 255.

def guess(original_image, add_color_image, add_edge_mask):
    """
    Ask the shared LLaVA model to guess what the user drew.

    Each input array is converted with ``numpy_to_tensor`` and passed to
    ``llavaModel.process``. The two answers it returns are joined with
    ``", "``, skipping any that are empty; the description is not used.
    """
    # LLaVA inference runs in the main process (threaded)
    tensors = [
        numpy_to_tensor(array)
        for array in (original_image, add_color_image, add_edge_mask)
    ]
    _description, ans1, ans2 = llavaModel.process(*tensors)

    non_empty = [answer for answer in (ans1, ans2) if answer and answer != ""]
    return ", ".join(non_empty)

def get_mask_bbox(mask_np):
    """
    Return the tight bounding box of the non-zero pixels of a mask.

    Accepts a mask of shape [H, W] or [1, H, W] (only the first plane of a 3-D
    input is used). Returns ``(x_min, y_min, x_max, y_max)`` as Python ints
    with inclusive max indices, or None when the mask has no non-zero pixel.
    """
    plane = mask_np[0] if mask_np.ndim == 3 else mask_np

    row_hits = np.any(plane, axis=1)
    col_hits = np.any(plane, axis=0)
    if not (np.any(row_hits) and np.any(col_hits)):
        return None

    ys = np.where(row_hits)[0]
    xs = np.where(col_hits)[0]
    top, bottom = ys[0], ys[-1]
    left, right = xs[0], xs[-1]
    return int(left), int(top), int(right), int(bottom)

def segment(image, coordinates_positive, coordinates_negative, bboxes):
    """
    Segment ``image`` with SAM in a worker process and report the mask's bounding box.

    Args:
        image: uint8 numpy image from the Gradio input, or None when nothing
            was provided.
        coordinates_positive: Positive point prompts (JSON string or list).
        coordinates_negative: Negative point prompts (JSON string or list).
        bboxes: Box prompts (JSON string or list).

    Returns:
        ``(pil_image, bbox_json)``. ``pil_image`` is built from the first array
        returned by ``run_sam_inference``; ``bbox_json`` encodes
        ``startX/startY/endX/endY`` (all zero for an empty mask). When there is
        no input image or the pool is not initialized, returns
        ``(None, '{"error": ...}')`` instead.

    Raises:
        Whatever ``future.result`` raises: a timeout after 60 seconds, or the
        exception raised inside the worker.
    """
    # Guard first: without an image there is nothing to log or submit.
    if image is None:
        return None, json.dumps({'error': 'No input image provided'})

    print("image.shape:", image.shape)
    print("coordinates_positive:", coordinates_positive)
    print("coordinates_negative:", coordinates_negative)
    print("bboxes:", bboxes)

    if sam_pool is None:
        return None, json.dumps({'error': 'SAM pool not initialized'})

    # Run inference in a SAM worker and block this request until it finishes.
    future = sam_pool.submit(run_sam_inference, image, coordinates_positive, coordinates_negative, bboxes)
    result_image_np, mask_np = future.result(timeout=60)  # 60s timeout

    # Convert back to PIL for Gradio
    res_pil = Image.fromarray(result_image_np)

    mask_bbox = get_mask_bbox(mask_np)
    if mask_bbox:
        x_min, y_min, x_max, y_max = mask_bbox
    else:
        # Empty mask: report a degenerate box at the origin.
        x_min = y_min = x_max = y_max = 0
    seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}

    return res_pil, json.dumps(seg_bbox)

# --- Gradio UI ---
# Builds the Blocks app with two tabs, each wiring one button to one of the
# handler functions above (guess / segment) and naming its API endpoint.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown("## MagicQuill Worker Server (Draw&Guess + SAM)")
    
    # Tab 1: three image inputs -> guess() -> one text output.
    with gr.Tab("Draw & Guess"):
        with gr.Row():
            dg_input_img = gr.Image(label="Original Image")
            dg_color_img = gr.Image(label="Colored Image")
            # Edge mask is requested in single-channel ("L") mode.
            dg_edge_img = gr.Image(image_mode="L", label="Edge Mask")
        dg_output = gr.Textbox(label="Prediction Output")
        dg_btn = gr.Button("Guess")
        
        # concurrency_limit=1: presumably serializes calls into the single
        # shared LLaVA model — confirm against Gradio's queue semantics.
        dg_btn.click(
            fn=guess,
            inputs=[dg_input_img, dg_color_img, dg_edge_img],
            outputs=dg_output,
            api_name="guess_prompt",
            concurrency_limit=1
        )
        
    # Tab 2: image + three JSON text prompts -> segment() -> image + bbox JSON.
    with gr.Tab("SAM Segmentation"):
        with gr.Row():
            sam_input_img = gr.Image(label="Input Image", type="numpy")
            sam_pos_coords = gr.Textbox(label="Pos Coords JSON")
            sam_neg_coords = gr.Textbox(label="Neg Coords JSON")
            sam_bboxes = gr.Textbox(label="BBoxes JSON")
        
        with gr.Row():
            sam_output_img = gr.Image(label="Segmented Image", format="png")
            sam_output_bbox = gr.Textbox(label="Mask BBox JSON")
            
        sam_btn = gr.Button("Segment")
        
        # concurrency_limit=5 equals NUM_SAM_WORKERS in the __main__ block.
        sam_btn.click(
            fn=segment,
            inputs=[sam_input_img, sam_pos_coords, sam_neg_coords, sam_bboxes],
            outputs=[sam_output_img, sam_output_bbox],
            api_name="segment",
            concurrency_limit=5
        )

# Script entry point: choose the process start method, create the SAM worker
# pool, then start the Gradio server. The order matters: the start method must
# be set before the pool is created, and the pool must exist before requests
# reach segment().
if __name__ == "__main__":
    # Set start method to spawn for CUDA compatibility
    multiprocessing.set_start_method('spawn', force=True)
    
    # Initialize SAM Pool
    # Adjust max_workers based on GPU memory (e.g., 2-4 workers for SAM-B)
    # NOTE(review): the value below (5) is outside the 2-4 range suggested
    # above; it equals the concurrency_limit of the "segment" endpoint.
    NUM_SAM_WORKERS = 5
    print(f"Starting {NUM_SAM_WORKERS} SAM worker processes...")
    # Each worker runs init_worker_sam once at startup to load its own SAM model.
    sam_pool = ProcessPoolExecutor(max_workers=NUM_SAM_WORKERS, initializer=init_worker_sam)
    
    # Launch Gradio
    # The request queue holds at most 40 pending events (max_size=40).
    app.queue(max_size=40).launch(max_threads=5)