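"""Multi-model image upscaler served through Gradio.

Three CPU-only backends are exposed: RealESRGAN x2plus (loaded with spandrel),
Swin2SR x2 (transformers), and the Stable Diffusion x4 upscaler (diffusers).
"""
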
import gradio as gr
import torch
from diffusers import StableDiffusionUpscalePipeline
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gc
from PIL import Image
import numpy as np
import logging
import io
import os
import requests
from spandrel import ModelLoader
# Setup logging
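# All records from the root logger are mirrored into an in-memory StringIO
# buffer so the Gradio UI can display them via get_logs().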
log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(ch)

def get_logs():
    """Return everything logged so far as a single string."""
    return log_capture_string.getvalue()

# Global models cache
models = {}
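# Each load_* helper below initializes its model on first use and stores it
# here, so switching models in the UI never reloads one that is already cached.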

def download_file(url, filename):
    """Download url to filename unless it is already cached locally."""
    if not os.path.exists(filename):
        logger.info(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()  # fail early instead of saving an error page
        with open(filename, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        logger.info(f"Downloaded {filename}.")
    return filename

def load_realesrgan_x2():
    if "realesrgan_x2" not in models:
        logger.info("Loading RealESRGAN x2plus model...")
        url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
        model_path = download_file(url, "RealESRGAN_x2plus.pth")
        model = ModelLoader().load_from_file(model_path)
        model.eval()
        # Keep the model on CPU (switch to CUDA here if a GPU is available)
        device = torch.device("cpu")
        model.to(device)
        models["realesrgan_x2"] = model
        logger.info("RealESRGAN x2plus loaded.")
    return models["realesrgan_x2"]

def load_swin2sr_x2():
    if "swin2sr_x2" not in models:
        logger.info("Loading Swin2SR x2 model...")
        model_id = "caidas/swin2SR-classical-sr-x2-64"
        processor = AutoImageProcessor.from_pretrained(model_id)
        model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
        models["swin2sr_x2"] = (processor, model)
        logger.info("Swin2SR x2 loaded.")
    return models["swin2sr_x2"]

def load_sd_x4():
    if "sd_x4" not in models:
        logger.info("Loading Stable Diffusion x4 model (this might take a while)...")
        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipe = StableDiffusionUpscalePipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        )
        # Reduce peak memory during CPU inference
        pipe.enable_attention_slicing("max")
        pipe.enable_vae_tiling()
        models["sd_x4"] = pipe
        logger.info("Stable Diffusion x4 loaded.")
    return models["sd_x4"]

def upscale_realesrgan(input_img):
    model = load_realesrgan_x2()
    # Convert PIL image to a normalized NCHW float tensor
    img_np = np.array(input_img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
    with torch.no_grad():
        output_tensor = model(img_tensor)
    # Convert the tensor back to a PIL image
    output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
    output_np = (output_np * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_np)

def upscale_swin2sr(input_img, scale=2):
    # scale is accepted for API symmetry; the cached model is fixed at x2
    processor, model = load_swin2sr_x2()
    inputs = processor(images=input_img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = np.moveaxis(output, source=0, destination=-1)
    output = (output * 255.0).round().astype(np.uint8)
    return Image.fromarray(output)

def upscale_diffusion_cpu(input_img, prompt):
    pipe = load_sd_x4()
    # Resize the input if it is too large, to prevent running out of memory
    max_size = 512
    if max(input_img.size) > max_size:
        ratio = max_size / max(input_img.size)
        new_size = (int(input_img.size[0] * ratio), int(input_img.size[1] * ratio))
        input_img = input_img.resize(new_size, Image.Resampling.LANCZOS)
        logger.warning(f"Resized input to {new_size} to prevent OOM")
    generator = torch.manual_seed(42)
    output = pipe(
        prompt=prompt,
        image=input_img,
        num_inference_steps=20,
        guidance_scale=7.0,
        generator=generator
    ).images[0]
    return output

def process_image(input_img, model_name, prompt):
    if input_img is None:
        return None, get_logs()
    # Ensure a 3-channel image so every backend receives the format it expects
    input_img = input_img.convert("RGB")
    logger.info(f"Processing image with {model_name}...")
    try:
        if model_name == "RealESRGAN x2":
            output = upscale_realesrgan(input_img)
        elif model_name == "Swin2SR x2":
            output = upscale_swin2sr(input_img, scale=2)
        elif model_name == "Stable Diffusion x4":
            output = upscale_diffusion_cpu(input_img, prompt)
        else:
            output = input_img  # Fallback: return the input unchanged
        gc.collect()
        logger.info("Processing complete.")
        return output, get_logs()
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        return None, get_logs()

desc = """
### Multi-Model Upscaler
Select a model to upscale your image.
* **RealESRGAN x2**: Very fast, sharp results. Best for general photos.
* **Swin2SR x2**: Accurate, good for compressed images. Slower than RealESRGAN.
* **Stable Diffusion x4**: Slow, creative, high memory usage. Adds details but may hallucinate.
"""

with gr.Blocks(title="Universal Upscaler") as iface:
    gr.Markdown(desc)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            model_selector = gr.Dropdown(
                choices=["RealESRGAN x2", "Swin2SR x2", "Stable Diffusion x4"],
                value="RealESRGAN x2",
                label="Select Model"
            )
            prompt_input = gr.Textbox(
                label="Prompt (for Stable Diffusion only)",
                value="highly detailed, 4k, sharp"
            )
            submit_btn = gr.Button("Upscale")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Upscaled Image")
            logs_output = gr.TextArea(label="Logs", interactive=False)
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, prompt_input],
        outputs=[output_image, logs_output]
    )

iface.launch()
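
# A minimal sketch of driving the pipeline without the UI, assuming a local
# file named "sample.png" (hypothetical) exists; process_image returns the
# upscaled PIL image plus the captured log text:
#
#   img = Image.open("sample.png")
#   upscaled, logs = process_image(img, "RealESRGAN x2", prompt="")
#   upscaled.save("sample_x2.png")
#
# Kept as a comment because iface.launch() above blocks until the Gradio
# server is stopped.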