Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Gradio app for VAREdit image editing model. | |
| Provides web interface for editing images with text instructions. | |
| """ | |
| import spaces | |
| import gradio as gr | |
| import os | |
| import tempfile | |
| from PIL import Image | |
| import logging | |
| from infer import load_model, generate_image | |
| import os | |
| from huggingface_hub import snapshot_download | |
| import torch | |
# Torch device: CUDA when torch reports it available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-wide logging at INFO level, with a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def edit_image(
    input_image: Image.Image,
    instruction: str,
    cfg: float = 4.0,
    tau: float = 0.5,
    seed: int = -1
) -> Image.Image:
    """Edit an image according to a natural-language instruction.

    The input image is written to a temporary JPEG file. That path is passed
    to ``generate_image`` together with the module-level ``model_components``.
    The temporary file is always removed afterwards, even when saving or
    generation fails.

    Args:
        input_image: Image to edit; ``None`` means nothing was uploaded.
        instruction: Text describing the desired edit.
        cfg: Guidance scale, forwarded unchanged to ``generate_image``.
        tau: Temperature, forwarded unchanged to ``generate_image``.
        seed: Seed forwarded to ``generate_image``; ``-1`` (or ``None``) is
            passed on as ``None`` to request a random seed.

    Returns:
        Whatever ``generate_image`` returns (the edited image).

    Raises:
        gr.Error: If no image or an empty instruction is supplied, or if
            saving/generation fails. The underlying exception is chained as
            the cause.
    """
    if input_image is None:
        raise gr.Error("Please upload an image")
    # Also guard against None so a missing value gets the friendly message
    # rather than an AttributeError from .strip().
    if not instruction or not instruction.strip():
        raise gr.Error("Please provide an editing instruction")

    # NOTE(review): `spaces` is imported but no @spaces.GPU decorator is visible
    # in this file. On ZeroGPU hardware, confirm that generate_image (infer.py)
    # is decorated; otherwise this call will not be given a GPU.
    temp_path = None
    try:
        # Reserve the temp path and close the handle *before* writing. This
        # way the path is known (and removed in `finally`) even if saving
        # fails, and PIL does not write to a file we still hold open.
        fd, temp_path = tempfile.mkstemp(suffix='.jpg')
        os.close(fd)

        # JPEG cannot store alpha or palette images (e.g. RGBA/P from PNG
        # uploads). Convert those to RGB; keep modes JPEG already supports.
        if input_image.mode not in ("1", "L", "RGB", "CMYK"):
            input_image = input_image.convert("RGB")
        input_image.save(temp_path, 'JPEG')

        result_image = generate_image(
            model_components,
            temp_path,
            instruction,
            cfg=cfg,
            tau=tau,
            # -1 is the UI's "random" sentinel.
            seed=None if seed is None or seed == -1 else int(seed),
        )
        return result_image
    except gr.Error:
        # Already a user-facing error; don't wrap it a second time.
        raise
    except Exception as e:
        logger.exception("Image editing failed: %s", e)
        raise gr.Error(f"Failed to edit image: {str(e)}") from e
    finally:
        # Clean up the temporary file.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the VAREdit image editor.

    Layout:
        - An input column: image, instruction, collapsible advanced settings
          (CFG, tau, seed) and an "Edit Image" button.
        - An output column with the edited image.
        - A row of clickable examples below.

    The button calls ``edit_image`` with all five inputs.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    # Single source of truth for the advanced-setting defaults. The examples
    # callback below reuses them, so it no longer silently falls back to
    # edit_image's own signature defaults (cfg=4.0, tau=0.5), which differ
    # from the slider defaults.
    default_cfg = 3.0
    default_tau = 0.1
    default_seed = -1

    with gr.Blocks(title="VAREdit Image Editor") as demo:
        gr.Markdown("# VAREdit Image Editor")
        gr.Markdown("Edit images using natural language instructions with the VAREdit model.")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Input Image",
                )
                instruction = gr.Textbox(
                    label="Editing Instruction",
                    placeholder="e.g., 'Remove glasses from this person', 'Change the sky to sunset', 'Add a hat'",
                    lines=2
                )
                with gr.Accordion("Advanced Settings", open=False):
                    cfg = gr.Slider(
                        minimum=1.0,
                        maximum=10.0,
                        value=default_cfg,
                        step=0.5,
                        label="CFG Scale (Guidance Strength)"
                    )
                    tau = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=default_tau,
                        step=0.01,
                        label="Temperature (Tau)"
                    )
                    seed = gr.Number(
                        value=default_seed,
                        label="Seed (-1 for random)",
                        precision=0
                    )
                edit_btn = gr.Button("Edit Image", variant="primary", size="lg")
            with gr.Column():
                output_image = gr.Image(
                    label="Edited Image",
                )
        # Example images and instructions
        gr.Markdown("## Examples")
        gr.Examples(
            examples=[
                ["assets/test_3.jpg", "change shirt to a black-and-white striped Breton top, add a red beret, set the background to an artist's loft with a window view of the Eiffel Tower"],
                ["assets/test.jpg", "Add glasses to this girl and change hair color to red"],
                ["assets/test_1.jpg", "replace all the bullets with shimmering, multi-colored butterflies."],
                ["assets/test_4.jpg", "Set the scene against a dark, blurred-out server room, make all text and arrows glow with a vibrant cyan light"],
            ],
            inputs=[input_image, instruction],
            outputs=output_image,
            # Examples only supply image + instruction; pin the remaining
            # settings to the same defaults the UI shows.
            fn=lambda img, inst: edit_image(
                img, inst, cfg=default_cfg, tau=default_tau, seed=default_seed
            ),
            cache_examples=False
        )
        # Set up event handler
        edit_btn.click(
            fn=edit_image,
            inputs=[input_image, instruction, cfg, tau, seed],
            outputs=output_image
        )
    return demo
# Download the VAREdit repository into a local directory of the same name,
# then load the 8B / 512px checkpoint from it. The repo id, the download
# target and the checkpoint path all derive from `model_path`, so they
# cannot drift apart.
model_path = "HiDream-ai/VAREdit"
snapshot_download(
    repo_id=model_path,
    max_workers=16,
    repo_type="model",
    local_dir=model_path,
)
model_components = load_model(
    model_path,
    os.path.join(model_path, "8B-512.pth"),
    "8B",
    512,
)
if __name__ == "__main__":
    demo = create_interface()
    # Enable request queuing (at most 50 waiting requests, default
    # concurrency limit of 16), then serve with the API page hidden.
    demo.queue(max_size=50, default_concurrency_limit=16)
    demo.launch(show_api=False)