import spaces
import gradio as gr
import torch
import numpy as np
import random
import time
import os
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, Qwen3ForCausalLM
from controlnet_aux.processor import Processor
from PIL import Image
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download
import torchvision.transforms as transforms

# Import pipeline and model
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
# --- Configuration & Paths ---
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280

# Hugging Face Repo IDs
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
CONTROLNET_REPO = "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union"
CONTROLNET_FILENAME = "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"

print(f"Loading Z-Image Turbo from {MODEL_REPO}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16
# --- FIX: Download Transformer Config & Weights Locally ---
print("Downloading transformer files...")
transformer_path = snapshot_download(
    repo_id=MODEL_REPO,
    allow_patterns=["transformer/*"],
    local_dir="models/transformer",
    local_dir_use_symlinks=False,
)
local_transformer_path = os.path.join(transformer_path, "transformer")
if not os.path.exists(os.path.join(local_transformer_path, "config.json")):
    local_transformer_path = transformer_path
print(f"Transformer files located at: {local_transformer_path}")
# --- 1. Load Transformer ---
print("Initializing Transformer...")
transformer = ZImageControlTransformer2DModel.from_pretrained(
    local_transformer_path,
    transformer_additional_kwargs={
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    },
).to(device, weight_dtype)
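# Note: "control_layers_places" and "control_in_dim" are passed through verbatim
# from the original code. They appear to select which transformer blocks receive
# injected control features and the channel count of the control latent
# (presumably matching the VAE latent channels); this is an inference from the
# names, not documented API behaviour.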
# --- 2. Download & Load ControlNet Weights ---
if not os.path.exists(CONTROLNET_FILENAME):
    print(f"Downloading ControlNet weights from {CONTROLNET_REPO}...")
    try:
        CONTROLNET_WEIGHTS = hf_hub_download(
            repo_id=CONTROLNET_REPO,
            filename=CONTROLNET_FILENAME,
        )
    except Exception as e:
        print(f"Failed to download ControlNet weights: {e}")
        CONTROLNET_WEIGHTS = None
else:
    CONTROLNET_WEIGHTS = CONTROLNET_FILENAME

if CONTROLNET_WEIGHTS:
    print(f"Loading ControlNet weights from {CONTROLNET_WEIGHTS}")
    try:
        state_dict = load_file(CONTROLNET_WEIGHTS)
        state_dict = state_dict.get("state_dict", state_dict)
        # strict=False so keys absent from the checkpoint (the base transformer
        # weights already loaded above) do not raise; m/u are the missing and
        # unexpected key lists reported by load_state_dict.
        m, u = transformer.load_state_dict(state_dict, strict=False)
        print(f"ControlNet Weights Loaded - Missing keys: {len(m)}, Unexpected keys: {len(u)}")
    except Exception as e:
        print(f"Error loading ControlNet weights: {e}")
else:
    print("Warning: Running without explicit ControlNet weights.")
# --- 3. Load Core Components ---
print("Loading VAE, Tokenizer, and Text Encoder...")
vae = AutoencoderKL.from_pretrained(
    MODEL_REPO,
    subfolder="vae",
).to(device, weight_dtype)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_REPO,
    subfolder="tokenizer",
)

# Qwen3ForCausalLM is still needed as the text encoder for the pipeline
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_REPO,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
).to(device)

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_REPO,
    subfolder="scheduler",
)
# --- 4. Assemble Pipeline ---
pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
pipe.to(device, weight_dtype)
print(f"Model loaded successfully on {device}!")
# --- Helper Functions ---
def rescale_image(image, scale, divisible_by=16):
    """Rescale image and ensure dimensions are divisible by the specified value."""
    if image is None:
        return None, 1024, 1024
    width, height = image.size
    new_width = int(width * scale)
    new_height = int(height * scale)
    new_width = (new_width // divisible_by) * divisible_by
    new_height = (new_height // divisible_by) * divisible_by
    if new_width > MAX_IMAGE_SIZE:
        new_width = MAX_IMAGE_SIZE
    if new_height > MAX_IMAGE_SIZE:
        new_height = MAX_IMAGE_SIZE
    resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    return resized, new_width, new_height
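# Worked example of the rounding above (illustration only): a 1000x750 upload
# at scale 1.0 is floored to 992x736 (multiples of 16); a 2000x1500 upload is
# floored to 2000x1488 and then clamped per side to 1280x1280, so very large
# images lose their aspect ratio at the MAX_IMAGE_SIZE cap.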
def get_image_latent(image):
    """Convert a PIL image to its VAE latent representation."""
    # Normalize image from [0, 1] to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    # FIX: only unsqueeze(0) for the batch dimension [B, C, H, W];
    # the second unsqueeze(2) caused a 5D-tensor error and was removed.
    img_tensor = transform(image).unsqueeze(0)
    img_tensor = img_tensor.to(device, weight_dtype)
    with torch.no_grad():
        latent = pipe.vae.encode(img_tensor).latent_dist.sample()
        latent = latent * pipe.vae.config.scaling_factor
    return latent
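# The returned latent is expected to have shape [1, C, H//f, W//f], where the
# channel count C and spatial downsampling factor f come from the Z-Image VAE
# config (C presumably matches control_in_dim above); this is an inference from
# the configuration rather than a documented guarantee.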
# @spaces.GPU requests a ZeroGPU worker for the duration of the call; without it
# the `spaces` import above is unused and generation would not receive a GPU on
# a ZeroGPU Space.
@spaces.GPU
def generate_image(
    prompt,
    negative_prompt="blurry, ugly, bad quality",
    input_image=None,
    control_mode="Canny",
    control_context_scale=0.75,
    image_scale=1.0,
    num_inference_steps=9,
    guidance_scale=1.0,
    seed=42,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True),
):
    if not prompt.strip():
        raise gr.Error("Please enter a prompt to generate an image.")

    # 1. Set Seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device).manual_seed(seed)

    # 2. Process Control Image
    if input_image is None:
        raise gr.Error("Please upload a control image.")

    progress(0.2, desc=f"Processing {control_mode}...")
    processor_map = {
        'Canny': 'canny',
        'HED': 'softedge_hed',
        'Depth': 'depth_midas',
        'MLSD': 'mlsd',
        'Pose': 'openpose_full',
    }
    processor_id = processor_map.get(control_mode, 'canny')
    try:
        processor = Processor(processor_id)
    except Exception as e:
        print(f"Failed to load processor {processor_id}, falling back to Canny. Error: {e}")
        processor = Processor('canny')

    control_image_rescaled, width, height = rescale_image(input_image, image_scale, 16)

    # Run the preprocessor at a fixed 1024x1024, then resize back to the target size
    temp_image = control_image_rescaled.resize((1024, 1024))
    processed_image_pil = processor(temp_image, to_pil=True)
    processed_image_pil = processed_image_pil.resize((width, height))

    # Convert to Latent
    progress(0.4, desc="Encoding control image...")
    # FIX: pass the processed image directly; the sample_size args are no longer used
    control_image_latent = get_image_latent(processed_image_pil)

    # 3. Generate
    progress(0.5, desc="Generating...")
    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            generator=generator,
            guidance_scale=guidance_scale,
            control_image=control_image_latent,
            num_inference_steps=num_inference_steps,
            control_context_scale=control_context_scale,
        )
        image = result.images[0]
        progress(1.0, desc="Complete!")
        return image, seed, processed_image_pil
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")
# --- UI Configuration (Apple Style) ---
apple_css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
    padding: 48px 20px !important;
    font-family: -apple-system, BlinkMacSystemFont, 'Inter', 'Segoe UI', sans-serif !important;
}
.header-container { text-align: center; margin-bottom: 48px; }
.main-title {
    font-size: 56px !important; font-weight: 600 !important;
    letter-spacing: -0.02em !important; color: #1d1d1f !important;
    margin: 0 0 12px 0 !important;
}
.subtitle {
    font-size: 21px !important; color: #6e6e73 !important;
    margin: 0 0 24px 0 !important;
}
.info-badge {
    display: inline-block; background: #0071e3; color: white;
    padding: 6px 16px; border-radius: 20px; font-size: 14px;
    font-weight: 500; margin-bottom: 16px;
}
textarea {
    font-size: 17px !important; border-radius: 12px !important;
    border: 1px solid #d2d2d7 !important; padding: 12px 16px !important;
}
textarea:focus {
    border-color: #0071e3 !important; box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important;
    outline: none !important;
}
button.primary {
    font-size: 17px !important; padding: 12px 32px !important;
    border-radius: 980px !important; background: #0071e3 !important;
    border: none !important; color: #ffffff !important;
    transition: all 0.2s ease !important;
}
button.primary:hover {
    background: #0077ed !important; transform: scale(1.02) !important;
}
.footer-text {
    text-align: center; margin-top: 48px; font-size: 14px !important;
    color: #86868b !important;
}
"""
# css must be passed to gr.Blocks; demo.launch() does not accept a css argument
with gr.Blocks(css=apple_css, title="Z-Image Turbo ControlNet") as demo:
    gr.HTML("""
    <div class="header-container">
        <div class="info-badge">✓ ControlNet Union</div>
        <h1 class="main-title">Z-Image Turbo</h1>
        <p class="subtitle">Multi-Control Generation</p>
    </div>
    """)
    with gr.Row():
        # Left Input Column
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to create...",
                lines=3,
            )
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, ugly, bad quality",
                lines=1,
            )
            input_image = gr.Image(
                label="Control Image (Required)",
                type="pil",
                sources=['upload', 'clipboard'],
                height=300,
            )
            control_mode = gr.Radio(
                choices=["Canny", "Depth", "HED", "MLSD", "Pose"],
                value="Canny",
                label="Control Mode",
                info="Select the type of structure to extract",
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=30, step=1, value=9)
                    guidance_scale = gr.Slider(label="Guidance", minimum=0.0, maximum=10.0, step=0.1, value=1.0)
                with gr.Row():
                    control_context_scale = gr.Slider(label="Control Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
                    image_scale = gr.Slider(label="Image Scale", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
            generate_btn = gr.Button("Generate Image", variant="primary", elem_classes="primary")

        # Right Output Column
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image", type="pil")
            with gr.Accordion("Details & Debug", open=True):
                with gr.Row():
                    seed_output = gr.Number(label="Seed Used", precision=0)
                    control_output = gr.Image(label="Preprocessor Output", type="pil")

    # Footer
    gr.HTML("""
    <div class="footer-text">
        Powered by Z-Image Turbo • VideoX-Fun • Tongyi-MAI
    </div>
    """)
    # Event Wiring
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt, negative_prompt, input_image, control_mode,
            control_context_scale, image_scale, num_inference_steps,
            guidance_scale, seed, randomize_seed,
        ],
        outputs=[output_image, seed_output, control_output],
    )
if __name__ == "__main__":
    demo.launch(share=False)