import spaces
import gradio as gr
import torch
import numpy as np
import random
import os

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, Qwen3ForCausalLM
from controlnet_aux.processor import Processor
from PIL import Image
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, snapshot_download

# Import pipeline and model
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel

# --- Configuration & Paths ---
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280

# Hugging Face Repo IDs
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
CONTROLNET_REPO = "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union"
CONTROLNET_FILENAME = "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"

print(f"Loading Z-Image Turbo from {MODEL_REPO}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16

# --- FIX: Download Transformer Config & Weights Locally ---
print("Downloading transformer files...")
transformer_path = snapshot_download(
    repo_id=MODEL_REPO,
    allow_patterns=["transformer/*"],
    local_dir="models/transformer",
    local_dir_use_symlinks=False,
)
local_transformer_path = os.path.join(transformer_path, "transformer")
if not os.path.exists(os.path.join(local_transformer_path, "config.json")):
    local_transformer_path = transformer_path
print(f"Transformer files located at: {local_transformer_path}")

# --- 1. Load Transformer ---
print("Initializing Transformer...")
transformer = ZImageControlTransformer2DModel.from_pretrained(
    local_transformer_path,
    transformer_additional_kwargs={
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    },
).to(device, weight_dtype)

# --- 2. Download & Load ControlNet Weights ---
if not os.path.exists(CONTROLNET_FILENAME):
    print(f"Downloading ControlNet weights from {CONTROLNET_REPO}...")
    try:
        CONTROLNET_WEIGHTS = hf_hub_download(
            repo_id=CONTROLNET_REPO,
            filename=CONTROLNET_FILENAME,
        )
    except Exception as e:
        print(f"Failed to download ControlNet weights: {e}")
        CONTROLNET_WEIGHTS = None
else:
    CONTROLNET_WEIGHTS = CONTROLNET_FILENAME

if CONTROLNET_WEIGHTS:
    print(f"Loading ControlNet weights from {CONTROLNET_WEIGHTS}")
    try:
        state_dict = load_file(CONTROLNET_WEIGHTS)
        state_dict = state_dict.get("state_dict", state_dict)
        m, u = transformer.load_state_dict(state_dict, strict=False)
        print(f"ControlNet Weights Loaded - Missing keys: {len(m)}, Unexpected keys: {len(u)}")
    except Exception as e:
        print(f"Error loading ControlNet weights: {e}")
else:
    print("Warning: Running without explicit ControlNet weights.")

# --- 3. Load Core Components ---
print("Loading VAE, Tokenizer, and Text Encoder...")
vae = AutoencoderKL.from_pretrained(
    MODEL_REPO,
    subfolder="vae",
).to(device, weight_dtype)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_REPO,
    subfolder="tokenizer",
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_REPO,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
).to(device)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_REPO,
    subfolder="scheduler",
)
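# Optional (an assumption, not part of the original app): if VRAM is tight,
# diffusers' AutoencoderKL supports sliced/tiled decoding, which lowers peak
# memory at large resolutions in exchange for some speed. Uncomment to enable:
# vae.enable_slicing()
# vae.enable_tiling()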
# --- 4. Assemble Pipeline ---
pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
pipe.to(device, weight_dtype)
print(f"Model loaded successfully on {device}!")


# --- Helper Functions ---
def rescale_image(image, scale, divisible_by=16):
    """Rescale an image, snapping dimensions down to a multiple of `divisible_by`
    and clamping to MAX_IMAGE_SIZE."""
    if image is None:
        return None, 1024, 1024
    width, height = image.size
    new_width = int(width * scale)
    new_height = int(height * scale)
    new_width = (new_width // divisible_by) * divisible_by
    new_height = (new_height // divisible_by) * divisible_by
    new_width = min(new_width, MAX_IMAGE_SIZE)
    new_height = min(new_height, MAX_IMAGE_SIZE)
    resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    return resized, new_width, new_height


@spaces.GPU()
def generate_image(
    prompt,
    negative_prompt="blurry, ugly, bad quality",
    input_image=None,
    control_mode="Canny",
    control_context_scale=0.75,
    image_scale=1.0,
    num_inference_steps=9,
    guidance_scale=1.0,
    seed=42,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True),
):
    # Validate inputs up front, before doing any work.
    if not prompt.strip():
        raise gr.Error("Please enter a prompt to generate an image.")
    if input_image is None:
        raise gr.Error("Please upload a control image.")

    # 1. Set Seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device).manual_seed(seed)

    # 2. Process Control Image
    progress(0.2, desc=f"Processing {control_mode}...")
    processor_map = {
        "Canny": "canny",
        "HED": "softedge_hed",
        "Depth": "depth_midas",
        "MLSD": "mlsd",
        "Pose": "openpose_full",
    }
    processor_id = processor_map.get(control_mode, "canny")
    try:
        processor = Processor(processor_id)
    except Exception as e:
        print(f"Failed to load processor {processor_id}, falling back to Canny. Error: {e}")
        processor = Processor("canny")

    control_image_rescaled, width, height = rescale_image(input_image, image_scale, 16)

    # Run the preprocessor at 1024x1024, where these annotators tend to work
    # best, then resize the result back to the target dimensions.
    temp_image = control_image_rescaled.resize((1024, 1024))
    processed_image_pil = processor(temp_image, to_pil=True)
    processed_image_pil = processed_image_pil.resize((width, height))

    # 3. Generate
    progress(0.5, desc="Generating...")
    try:
        # FIX: Pass the processed PIL image directly; the pipeline handles
        # VAE encoding internally.
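        # Note (assumption): Z-Image Turbo is a few-step distilled model, which is
        # presumably why the defaults use ~9 steps with guidance_scale=1.0
        # (classifier-free guidance effectively off). control_context_scale weights
        # the control branch; lower values loosen adherence to the control image.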
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            generator=generator,
            guidance_scale=guidance_scale,
            control_image=processed_image_pil,
            num_inference_steps=num_inference_steps,
            control_context_scale=control_context_scale,
        )
        image = result.images[0]
        progress(1.0, desc="Complete!")
        return image, seed, processed_image_pil
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}")


# --- UI Configuration (Apple Style) ---
apple_css = """
.gradio-container { max-width: 1200px !important; margin: 0 auto !important; padding: 48px 20px !important; font-family: -apple-system, BlinkMacSystemFont, 'Inter', 'Segoe UI', sans-serif !important; }
.header-container { text-align: center; margin-bottom: 48px; }
.main-title { font-size: 56px !important; font-weight: 600 !important; letter-spacing: -0.02em !important; color: #1d1d1f !important; margin: 0 0 12px 0 !important; }
.subtitle { font-size: 21px !important; color: #6e6e73 !important; margin: 0 0 24px 0 !important; }
.info-badge { display: inline-block; background: #0071e3; color: white; padding: 6px 16px; border-radius: 20px; font-size: 14px; font-weight: 500; margin-bottom: 16px; }
textarea { font-size: 17px !important; border-radius: 12px !important; border: 1px solid #d2d2d7 !important; padding: 12px 16px !important; }
textarea:focus { border-color: #0071e3 !important; box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important; outline: none !important; }
button.primary { font-size: 17px !important; padding: 12px 32px !important; border-radius: 980px !important; background: #0071e3 !important; border: none !important; color: #ffffff !important; transition: all 0.2s ease !important; }
button.primary:hover { background: #0077ed !important; transform: scale(1.02) !important; }
.footer-text { text-align: center; margin-top: 48px; font-size: 14px !important; color: #86868b !important; }
"""

# FIX: apple_css was defined but never applied; pass it to gr.Blocks.
with gr.Blocks(title="Z-Image Turbo ControlNet", css=apple_css) as demo:
    gr.HTML("""
Multi-Control Generation