Spaces:
Paused
Paused
| import os | |
| import tempfile | |
| import torch | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| import cv2 | |
| from diffusers import DiffusionPipeline | |
| from script import SatelliteModelGenerator | |
# Choose the compute device once, at import time; every request reuses it.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Weights are loaded in bfloat16 on either device.
# NOTE(review): bfloat16 on CPU works but is likely very slow — confirm intended.
dtype = torch.bfloat16

# Load the FLUX checkpoint used for satellite imagery and move it onto the
# selected device. This runs once per process when the module is imported.
flux_pipe = DiffusionPipeline.from_pretrained(
    "jbilcke-hf/flux-satellite", torch_dtype=dtype,
).to(device)
| def generate_and_process_map(prompt: str) -> str | None: | |
| """Generate satellite image from prompt and convert to 3D model.""" | |
| try: | |
| # Set dimensions | |
| width = height = 1024 | |
| # Generate random seed | |
| seed = np.random.randint(0, np.iinfo(np.int32).max) | |
| # Set random seeds | |
| torch.manual_seed(seed) | |
| np.random.seed(seed) | |
| # Generate satellite image using FLUX | |
| generator = torch.Generator(device=device).manual_seed(seed) | |
| generated_image = flux_pipe( | |
| prompt=prompt, | |
| width=width, | |
| height=height, | |
| num_inference_steps=30, | |
| generator=generator, | |
| guidance_scale=7.5 | |
| ).images[0] | |
| # Convert PIL Image to OpenCV format | |
| cv_image = cv2.cvtColor(np.array(generated_image), cv2.COLOR_RGB2BGR) | |
| # Initialize SatelliteModelGenerator | |
| generator = SatelliteModelGenerator(building_height=0.09) | |
| # Process image | |
| print("Segmenting image...") | |
| segmented_img = generator.segment_image(cv_image, window_size=5) | |
| print("Estimating heights...") | |
| height_map = generator.estimate_heights(cv_image, segmented_img) | |
| # Generate mesh | |
| print("Generating mesh...") | |
| mesh = generator.generate_mesh(height_map, cv_image, add_walls=True) | |
| # Export to GLB | |
| temp_dir = tempfile.mkdtemp() | |
| output_path = os.path.join(temp_dir, 'output.glb') | |
| mesh.export(output_path) | |
| return output_path | |
| except Exception as e: | |
| print(f"Error during generation: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| return None | |
# Build the web UI: a title, a prompt box, a button and a 3D viewer.
with gr.Blocks() as demo:
    gr.Markdown("# Text to Map")
    gr.Markdown("Generate 3D maps from text descriptions using FLUX and mesh generation.")

    # Each component gets its own row, so they stack vertically.
    with gr.Row():
        prompt_input = gr.Text(
            placeholder="eg. satellite view of downtown Manhattan",
            label="Enter your prompt",
        )
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
    with gr.Row():
        # RGBA clear colour with alpha 0 for the viewer background.
        model_output = gr.Model3D(
            clear_color=[0.0, 0.0, 0.0, 0.0],
            label="Generated 3D Map",
        )

    # Clicking the button sends the prompt to the generator and shows the
    # returned GLB path in the viewer. The handler is also exposed under the
    # API name "generate".
    generate_btn.click(
        api_name="generate",
        fn=generate_and_process_map,
        inputs=[prompt_input],
        outputs=[model_output],
    )
if __name__ == "__main__":
    # queue() returns the same Blocks app, so launch() is called on its result.
    app = demo.queue()
    app.launch()