Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionUpscalePipeline | |
| from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution | |
| import gc | |
| from PIL import Image | |
| import numpy as np | |
| import logging | |
| import io | |
| import os | |
| import requests | |
| from spandrel import ModelLoader | |
# --- Logging -----------------------------------------------------------------
# The root logger is used, so records propagated from other loggers are
# captured as well. Everything at INFO or above is mirrored into an in-memory
# buffer so the UI can show it in its "Logs" panel.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
logger.addHandler(ch)


def get_logs():
    """Return the full text captured in the in-memory log buffer so far."""
    return log_capture_string.getvalue()


# Lazily populated cache of loaded models, keyed by a short model id.
models = dict()
def download_file(url, filename, timeout=60):
    """Download ``url`` to ``filename`` unless that file already exists.

    The body is streamed into ``<filename>.part`` and only renamed to
    ``filename`` once the whole transfer has succeeded. An interrupted or
    failed download therefore never leaves a truncated file behind that later
    calls would mistake for a finished one.

    Args:
        url: Source URL.
        filename: Destination path. Its existence is the cache check.
        timeout: Seconds passed to ``requests.get`` as the connect/read timeout.

    Returns:
        ``filename``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection or timeout errors.
    """
    if os.path.exists(filename):
        return filename
    logger.info(f"Downloading {filename}...")
    tmp_path = filename + ".part"
    try:
        # The context manager closes the streamed connection even on error.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check, an error page (e.g. a 404 HTML body) would be
            # saved as the weights file and fail later with a confusing error.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the file only once it is complete.
        os.replace(tmp_path, filename)
    except Exception:
        # Best-effort cleanup of the partial file; the next call retries.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    logger.info(f"Downloaded {filename}.")
    return filename
def load_realesrgan_x2():
    """Return the cached RealESRGAN x2plus model, loading it on first use.

    On first use, the weights are fetched from the Real-ESRGAN GitHub release
    into the working directory if they are not already there. They are then
    loaded with spandrel's ``ModelLoader``, switched to eval mode and placed on
    the CPU.
    """
    cache_key = "realesrgan_x2"
    if cache_key in models:
        return models[cache_key]

    logger.info("Loading RealESRGAN x2plus model...")
    weights_path = download_file(
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
        "RealESRGAN_x2plus.pth",
    )
    net = ModelLoader().load_from_file(weights_path)
    net.eval()
    # Deliberately CPU-only: CUDA is not used even when it is available.
    net.to(torch.device("cpu"))
    models[cache_key] = net
    logger.info("RealESRGAN x2plus loaded.")
    return models[cache_key]
def load_swin2sr_x2():
    """Return the cached ``(processor, model)`` pair for Swin2SR x2.

    Both objects are loaded from the Hugging Face Hub checkpoint
    ``caidas/swin2SR-classical-sr-x2-64`` on first use and reused afterwards.
    """
    key = "swin2sr_x2"
    if key not in models:
        logger.info("Loading Swin2SR x2 model...")
        checkpoint = "caidas/swin2SR-classical-sr-x2-64"
        # The tuple is built left to right: the processor loads before the model.
        models[key] = (
            AutoImageProcessor.from_pretrained(checkpoint),
            Swin2SRForImageSuperResolution.from_pretrained(checkpoint),
        )
        logger.info("Swin2SR x2 loaded.")
    return models[key]
def load_sd_x4():
    """Return the cached Stable Diffusion x4 upscaling pipeline.

    On first use, ``stabilityai/stable-diffusion-x4-upscaler`` is loaded in
    float32 with ``low_cpu_mem_usage=True``. Attention slicing ("max") and VAE
    tiling are then enabled to reduce peak memory.
    """
    key = "sd_x4"
    if key in models:
        return models[key]

    logger.info("Loading Stable Diffusion x4 model (this might take a while)...")
    pipeline = StableDiffusionUpscalePipeline.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    # Trade speed for lower peak memory during attention and VAE decoding.
    pipeline.enable_attention_slicing("max")
    pipeline.enable_vae_tiling()
    models[key] = pipeline
    logger.info("Stable Diffusion x4 loaded.")
    return models[key]
def upscale_realesrgan(input_img):
    """Upscale a PIL image with RealESRGAN x2plus on the CPU.

    Args:
        input_img: PIL image in any mode. It is converted to RGB first:
            RealESRGAN x2plus is an RGB model. An RGBA image would also yield a
            4-channel tensor, and a grayscale ("L") image yields a 2-D array
            for which ``permute(2, 0, 1)`` fails.

    Returns:
        A new RGB ``PIL.Image`` holding the upscaled result.
    """
    model = load_realesrgan_x2()
    rgb = input_img.convert("RGB")
    # HWC uint8 -> 1xCxHxW float32 in [0, 1]
    img_np = np.array(rgb).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
    with torch.no_grad():
        output_tensor = model(img_tensor)
    # 1xCxHxW float -> HWC uint8
    output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    output_np = (output_np * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_np)
def upscale_swin2sr(input_img, scale=2):
    """Upscale a PIL image with the Swin2SR classical x2 model.

    Args:
        input_img: PIL image in any mode; converted to RGB first.
        scale: Expected upscale factor. It is used only as a fallback crop
            factor if the loaded model config does not declare ``upscale``;
            kept for backward compatibility.

    Returns:
        RGB ``PIL.Image`` of size ``(width * factor, height * factor)``.
    """
    processor, model = load_swin2sr_x2()
    rgb = input_img.convert("RGB")
    width, height = rgb.size
    inputs = processor(images=rgb, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    factor = getattr(model.config, "upscale", scale)
    output = outputs.reconstruction.data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = np.moveaxis(output, source=0, destination=-1)  # CHW -> HWC
    # The image processor pads the input at the bottom/right before inference
    # (do_pad / pad_size). The reconstruction therefore contains an upscaled
    # copy of that padding; crop it back to the true upscaled size.
    output = output[: height * factor, : width * factor]
    output = (output * 255.0).round().astype(np.uint8)
    return Image.fromarray(output)
def upscale_diffusion_cpu(input_img, prompt, max_size=512, seed=42):
    """Upscale a PIL image with the Stable Diffusion x4 upscaler on the CPU.

    Args:
        input_img: PIL image in any mode; converted to RGB first.
        prompt: Text prompt guiding the diffusion upscaler.
        max_size: Longest allowed input side in pixels. Larger inputs are
            downscaled with LANCZOS to this bound to limit memory use.
        seed: Seed for the private random generator that drives the diffusion
            noise, so repeated calls give the same output.

    Returns:
        The first image produced by the pipeline (a ``PIL.Image``).
    """
    pipe = load_sd_x4()
    input_img = input_img.convert("RGB")
    if max(input_img.size) > max_size:
        ratio = max_size / max(input_img.size)
        # max(1, ...) stops very thin images collapsing to a zero-sized side,
        # which resize() rejects.
        new_size = tuple(max(1, int(side * ratio)) for side in input_img.size)
        input_img = input_img.resize(new_size, Image.Resampling.LANCZOS)
        logger.warning(f"Resized input to {new_size} to prevent OOM")
    # A dedicated seeded generator keeps results reproducible without
    # reseeding torch's global RNG, which torch.manual_seed() would do.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    output = pipe(
        prompt=prompt,
        image=input_img,
        num_inference_steps=20,
        guidance_scale=7.0,
        generator=generator,
    ).images[0]
    return output
def process_image(input_img, model_name, prompt):
    """Gradio click handler: upscale ``input_img`` with the selected model.

    Args:
        input_img: PIL image from the input component, or ``None`` when
            nothing has been uploaded.
        model_name: One of "RealESRGAN x2", "Swin2SR x2" or
            "Stable Diffusion x4". Any other value returns the input unchanged
            and logs a warning.
        prompt: Text prompt; only used by the Stable Diffusion upscaler.

    Returns:
        ``(output_image, logs)``. ``output_image`` is the upscaled image, or
        ``None`` when there was no input or processing failed. ``logs`` is the
        full captured log text.
    """
    if input_img is None:
        return None, get_logs()

    logger.info(f"Processing image with {model_name}...")
    output = None
    try:
        if model_name == "RealESRGAN x2":
            output = upscale_realesrgan(input_img)
        elif model_name == "Swin2SR x2":
            output = upscale_swin2sr(input_img, scale=2)
        elif model_name == "Stable Diffusion x4":
            output = upscale_diffusion_cpu(input_img, prompt)
        else:
            logger.warning(f"Unknown model {model_name!r}; returning the input unchanged.")
            output = input_img  # Fallback
        logger.info("Processing complete.")
    except Exception as e:
        # Top-level UI boundary: report the failure in the Logs panel instead
        # of raising. logger.exception also records the traceback.
        logger.exception(f"Error: {e}")
        output = None
    finally:
        # Collect garbage after every attempt, failed ones included.
        gc.collect()
    return output, get_logs()
# Markdown shown at the top of the UI, describing each available model.
desc = (
    "\n"
    "### Multi-Model Upscaler\n"
    "Select a model to upscale your image.\n"
    "* **RealESRGAN x2**: Very fast, sharp results. Best for general photos.\n"
    "* **Swin2SR x2**: Accurate, good for compressed images. Slower than RealESRGAN.\n"
    "* **Stable Diffusion x4**: Slow, creative, high memory usage. Adds details but may hallucinate.\n"
)
# Gradio UI: image, model choice and prompt on the left; result and logs on
# the right.
with gr.Blocks(title="Universal Upscaler") as iface:
    gr.Markdown(desc)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            model_selector = gr.Dropdown(
                choices=["RealESRGAN x2", "Swin2SR x2", "Stable Diffusion x4"],
                value="RealESRGAN x2",
                label="Select Model",
            )
            prompt_input = gr.Textbox(
                label="Prompt (for Stable Diffusion only)",
                value="highly detailed, 4k, sharp",
            )
            submit_btn = gr.Button("Upscale")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Upscaled Image")
            logs_output = gr.TextArea(label="Logs", interactive=False)
    # process_image returns (image, logs), matching the two output components.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, prompt_input],
        outputs=[output_image, logs_output],
    )

# Start the server only when this file is executed directly, so importing the
# module (e.g. from tests or tooling) does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch()