import gradio as gr
import torch
from diffusers import StableDiffusionUpscalePipeline
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gc
from PIL import Image
import numpy as np
import logging
import io
import os
import requests
from spandrel import ModelLoader

# Setup logging: capture records in an in-memory buffer so they can be shown in the UI
log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(ch)


def get_logs():
    return log_capture_string.getvalue()


# Global models cache (lazy-loaded, kept in memory between requests)
models = {}


def download_file(url, filename):
    if not os.path.exists(filename):
        logger.info(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(filename, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        logger.info(f"Downloaded {filename}.")
    return filename


def load_realesrgan_x2():
    if "realesrgan_x2" not in models:
        logger.info("Loading RealESRGAN x2plus model...")
        url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
        model_path = download_file(url, "RealESRGAN_x2plus.pth")
        model = ModelLoader().load_from_file(model_path)
        model.eval()
        # Move to CPU (or CUDA if available, but we focus on CPU here)
        device = torch.device("cpu")
        model.to(device)
        models["realesrgan_x2"] = model
        logger.info("RealESRGAN x2plus loaded.")
    return models["realesrgan_x2"]


def load_swin2sr_x2():
    if "swin2sr_x2" not in models:
        logger.info("Loading Swin2SR x2 model...")
        model_id = "caidas/swin2SR-classical-sr-x2-64"
        processor = AutoImageProcessor.from_pretrained(model_id)
        model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
        models["swin2sr_x2"] = (processor, model)
        logger.info("Swin2SR x2 loaded.")
    return models["swin2sr_x2"]


def load_sd_x4():
    if "sd_x4" not in models:
        logger.info("Loading Stable Diffusion x4 model (this might take a while)...")
        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        )
        # Reduce peak memory on CPU
        pipe.enable_attention_slicing("max")
        pipe.enable_vae_tiling()
        models["sd_x4"] = pipe
        logger.info("Stable Diffusion x4 loaded.")
    return models["sd_x4"]


def upscale_realesrgan(input_img):
    model = load_realesrgan_x2()
    # Convert PIL to tensor (NCHW, float in [0, 1]); force RGB in case of RGBA/grayscale inputs
    img_np = np.array(input_img.convert("RGB")).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
    with torch.no_grad():
        output_tensor = model(img_tensor)
    # Convert tensor back to PIL
    output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
    output_np = (output_np * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_np)


def upscale_swin2sr(input_img, scale=2):
    processor, model = load_swin2sr_x2()
    inputs = processor(images=input_img.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = np.moveaxis(output, source=0, destination=-1)
    output = (output * 255.0).round().astype(np.uint8)
    return Image.fromarray(output)


def upscale_diffusion_cpu(input_img, prompt):
    pipe = load_sd_x4()
    # Resize input if too large to prevent OOM
    max_size = 512
    if max(input_img.size) > max_size:
        ratio = max_size / max(input_img.size)
        new_size = (int(input_img.size[0] * ratio), int(input_img.size[1] * ratio))
        input_img = input_img.resize(new_size, Image.Resampling.LANCZOS)
        logger.warning(f"Resized input to {new_size} to prevent OOM")
    generator = torch.manual_seed(42)
    output = pipe(
        prompt=prompt,
        image=input_img,
        num_inference_steps=20,
        guidance_scale=7.0,
        generator=generator
    ).images[0]
    return output


def process_image(input_img, model_name, prompt):
    if input_img is None:
        return None, get_logs()
    logger.info(f"Processing image with {model_name}...")
    try:
        if model_name == "RealESRGAN x2":
            output = upscale_realesrgan(input_img)
        elif model_name == "Swin2SR x2":
            output = upscale_swin2sr(input_img, scale=2)
        elif model_name == "Stable Diffusion x4":
            output = upscale_diffusion_cpu(input_img, prompt)
        else:
            output = input_img  # Fallback: return the input unchanged
        gc.collect()
        logger.info("Processing complete.")
        return output, get_logs()
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        return None, get_logs()


desc = """
### Multi-Model Upscaler
Select a model to upscale your image.

* **RealESRGAN x2**: Very fast, sharp results. Best for general photos.
* **Swin2SR x2**: Accurate, good for compressed images. Slower than RealESRGAN.
* **Stable Diffusion x4**: Slow, creative, high memory usage. Adds details but may hallucinate.
"""

with gr.Blocks(title="Universal Upscaler") as iface:
    gr.Markdown(desc)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            model_selector = gr.Dropdown(
                choices=["RealESRGAN x2", "Swin2SR x2", "Stable Diffusion x4"],
                value="RealESRGAN x2",
                label="Select Model"
            )
            prompt_input = gr.Textbox(
                label="Prompt (for Stable Diffusion only)",
                value="highly detailed, 4k, sharp"
            )
            submit_btn = gr.Button("Upscale")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Upscaled Image")
            logs_output = gr.TextArea(label="Logs", interactive=False)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, prompt_input],
        outputs=[output_image, logs_output]
    )

iface.launch()
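# ---------------------------------------------------------------------------
# Usage sketch (assumptions, not part of the original script): the file is
# saved standalone, e.g. as app.py, and the dependency list below is inferred
# from the imports above; exact package versions are not pinned here.
#
#   pip install gradio torch diffusers transformers spandrel pillow numpy requests
#   python app.py
#
# By default, Gradio serves the UI at http://127.0.0.1:7860.
# ---------------------------------------------------------------------------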