Commit 8c2280a (verified) · akhaliq (HF Staff) · 1 parent: f33ac27

Update app.py

Files changed (1): app.py (+21, -62)
app.py CHANGED
@@ -11,20 +11,12 @@ from controlnet_aux.processor import Processor
 from PIL import Image
 from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download, snapshot_download
+import torchvision.transforms as transforms
 
 # Import pipeline and model
-# Ensure the videox_fun folder is in your current directory
 from videox_fun.pipeline import ZImageControlPipeline
 from videox_fun.models import ZImageControlTransformer2DModel
 
-# Try to import prompt utility, define fallback if missing
-try:
-    from utils.prompt_utils import polish_prompt
-except ImportError:
-    print("utils.prompt_utils not found. Using passthrough for prompt polishing.")
-    def polish_prompt(prompt):
-        return prompt
-
 # --- Configuration & Paths ---
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1280
@@ -40,20 +32,15 @@ weight_dtype = torch.bfloat16
 
 # --- FIX: Download Transformer Config & Weights Locally ---
 print("Downloading transformer files...")
-# This downloads the 'transformer' subfolder to a local cache and returns the path
 transformer_path = snapshot_download(
     repo_id=MODEL_REPO,
     allow_patterns=["transformer/*"],
     local_dir="models/transformer",
     local_dir_use_symlinks=False
 )
-# The snapshot puts files in models/transformer/transformer, we need to point to the inner one
-# depending on how snapshot_download behaves with 'allow_patterns'.
-# Usually it preserves structure. Let's ensure we point to the folder containing config.json.
 local_transformer_path = os.path.join(transformer_path, "transformer")
 
 if not os.path.exists(os.path.join(local_transformer_path, "config.json")):
-    # Fallback if structure is flat or different
     local_transformer_path = transformer_path
 
 print(f"Transformer files located at: {local_transformer_path}")
@@ -61,7 +48,7 @@ print(f"Transformer files located at: {local_transformer_path}")
 # --- 1. Load Transformer ---
 print("Initializing Transformer...")
 transformer = ZImageControlTransformer2DModel.from_pretrained(
-    local_transformer_path, # Pass the LOCAL path now
+    local_transformer_path,
     transformer_additional_kwargs={
         "control_layers_places": [0, 5, 10, 15, 20, 25],
         "control_in_dim": 16
@@ -69,7 +56,6 @@ transformer = ZImageControlTransformer2DModel.from_pretrained(
 ).to(device, weight_dtype)
 
 # --- 2. Download & Load ControlNet Weights ---
-# Check if weights exist locally; if not, download them
 if not os.path.exists(CONTROLNET_FILENAME):
     print(f"Downloading ControlNet weights from {CONTROLNET_REPO}...")
     try:
@@ -87,9 +73,7 @@ if CONTROLNET_WEIGHTS:
     print(f"Loading ControlNet weights from {CONTROLNET_WEIGHTS}")
     try:
         state_dict = load_file(CONTROLNET_WEIGHTS)
-        # Handle potential nesting of state_dict
         state_dict = state_dict.get("state_dict", state_dict)
-
         m, u = transformer.load_state_dict(state_dict, strict=False)
         print(f"ControlNet Weights Loaded - Missing keys: {len(m)}, Unexpected keys: {len(u)}")
     except Exception as e:
@@ -99,8 +83,6 @@ else:
 
 # --- 3. Load Core Components ---
 print("Loading VAE, Tokenizer, and Text Encoder...")
-# These standard libraries usually handle Hub IDs fine, but we can download if they fail too.
-# For now, standard diffusers/transformers components usually work with Hub IDs.
 vae = AutoencoderKL.from_pretrained(
     MODEL_REPO,
     subfolder="vae",
@@ -111,6 +93,7 @@ tokenizer = AutoTokenizer.from_pretrained(
     subfolder="tokenizer"
 )
 
+# Qwen3ForCausalLM is still needed as the Text Encoder for the pipeline
 text_encoder = Qwen3ForCausalLM.from_pretrained(
     MODEL_REPO,
     subfolder="text_encoder",
@@ -144,11 +127,9 @@ def rescale_image(image, scale, divisible_by=16):
     new_width = int(width * scale)
     new_height = int(height * scale)
 
-    # Make dimensions divisible by divisible_by
     new_width = (new_width // divisible_by) * divisible_by
     new_height = (new_height // divisible_by) * divisible_by
 
-    # Clamp to max size
     if new_width > MAX_IMAGE_SIZE:
         new_width = MAX_IMAGE_SIZE
     if new_height > MAX_IMAGE_SIZE:
@@ -157,17 +138,17 @@
     resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
     return resized, new_width, new_height
 
-def get_image_latent(image, sample_size):
+def get_image_latent(image):
     """Convert PIL image to VAE latent representation."""
-    import torchvision.transforms as transforms
-
     # Normalize image
     transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])
     ])
 
-    img_tensor = transform(image).unsqueeze(0).unsqueeze(2) # [B, C, 1, H, W]
+    # FIX: Only unsqueeze(0) for Batch dimension [B, C, H, W]
+    # Removed the second unsqueeze(2) which caused the 5D error
+    img_tensor = transform(image).unsqueeze(0)
     img_tensor = img_tensor.to(device, weight_dtype)
 
     with torch.no_grad():
@@ -188,36 +169,22 @@ def generate_image(
     guidance_scale=1.0,
     seed=42,
     randomize_seed=True,
-    is_polish_prompt=True,
     progress=gr.Progress(track_tqdm=True)
 ):
-    timestamp = time.time()
-
     if not prompt.strip():
         raise gr.Error("Please enter a prompt to generate an image.")
 
-    # 1. Polish Prompt
-    final_prompt = prompt
-    if is_polish_prompt:
-        progress(0.1, desc="Polishing prompt...")
-        try:
-            final_prompt = polish_prompt(prompt)
-        except Exception as e:
-            print(f"Prompt polish failed: {e}")
-            final_prompt = prompt
-
-    # 2. Set Seed
+    # 1. Set Seed
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator(device).manual_seed(seed)
 
-    # 3. Process Control Image
+    # 2. Process Control Image
     if input_image is None:
         raise gr.Error("Please upload a control image.")
 
     progress(0.2, desc=f"Processing {control_mode}...")
 
-    # Map control mode to processor ID
     processor_map = {
         'Canny': 'canny',
         'HED': 'softedge_hed',
@@ -227,34 +194,30 @@
     }
     processor_id = processor_map.get(control_mode, 'canny')
 
-    # Initialize processor
     try:
         processor = Processor(processor_id)
     except Exception as e:
         print(f"Failed to load processor {processor_id}, falling back to Canny. Error: {e}")
         processor = Processor('canny')
 
-    # Resize input for processing
     control_image_rescaled, width, height = rescale_image(input_image, image_scale, 16)
 
-    # Run Processor (requires resizing to 1024x1024 typically for best results with these models, then back)
+    # Run Processor
     temp_image = control_image_rescaled.resize((1024, 1024))
     processed_image_pil = processor(temp_image, to_pil=True)
     processed_image_pil = processed_image_pil.resize((width, height))
 
     # Convert to Latent
     progress(0.4, desc="Encoding control image...")
-    control_image_latent = get_image_latent(
-        processed_image_pil,
-        sample_size=[height, width]
-    )[:, :, 0]
+    # FIX: Passed result directly without sample_size args which aren't used in new function
+    control_image_latent = get_image_latent(processed_image_pil)
 
-    # 4. Generate
+    # 3. Generate
     progress(0.5, desc="Generating...")
 
     try:
         result = pipe(
-            prompt=final_prompt,
+            prompt=prompt,
             negative_prompt=negative_prompt,
             height=height,
             width=width,
@@ -268,7 +231,7 @@
         image = result.images[0]
         progress(1.0, desc="Complete!")
 
-        return image, seed, processed_image_pil, final_prompt
+        return image, seed, processed_image_pil
 
     except Exception as e:
         raise gr.Error(f"Generation failed: {str(e)}")
@@ -320,13 +283,13 @@ button.primary:hover {
 }
 """
 
-with gr.Blocks(title="Z-Image Turbo ControlNet") as demo:
+with gr.Blocks(title="Z-Image Turbo ControlNet", css=apple_css) as demo:
 
     gr.HTML("""
     <div class="header-container">
         <div class="info-badge">✓ ControlNet Union</div>
         <h1 class="main-title">Z-Image Turbo</h1>
-        <p class="subtitle">Multi-Control Generation with LLM Prompt Polishing</p>
+        <p class="subtitle">Multi-Control Generation</p>
     </div>
     """)
 
@@ -339,9 +302,7 @@ with gr.Blocks(title="Z-Image Turbo ControlNet") as demo:
                 lines=3
             )
 
-            with gr.Row():
-                is_polish_prompt = gr.Checkbox(label="Polish Prompt with LLM", value=True)
-                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
             negative_prompt = gr.Textbox(
                 label="Negative Prompt",
@@ -381,7 +342,6 @@ with gr.Blocks(title="Z-Image Turbo ControlNet") as demo:
             output_image = gr.Image(label="Generated Image", type="pil")
 
            with gr.Accordion("Details & Debug", open=True):
-                polished_prompt_output = gr.Textbox(label="Actual Polished Prompt", interactive=False, lines=2)
                with gr.Row():
                    seed_output = gr.Number(label="Seed Used", precision=0)
                    control_output = gr.Image(label="Preprocessor Output", type="pil")
@@ -399,11 +359,10 @@ with gr.Blocks(title="Z-Image Turbo ControlNet") as demo:
         inputs=[
             prompt, negative_prompt, input_image, control_mode,
             control_context_scale, image_scale, num_inference_steps,
-            guidance_scale, seed, randomize_seed, is_polish_prompt
+            guidance_scale, seed, randomize_seed
         ],
-        outputs=[output_image, seed_output, control_output, polished_prompt_output]
+        outputs=[output_image, seed_output, control_output]
     )
 
 if __name__ == "__main__":
-    demo.launch(share=False,
-                css=apple_css)
+    demo.launch(share=False)