# Hugging Face file-viewer header captured with the source (not part of the program):
# jgitsolutions's picture
# Upload 3 files
# e58a107 verified
# raw
# history blame
# 27.3 kB
import gradio as gr
import torch
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import gc
from PIL import Image
import numpy as np
import logging
import io
import os
import requests
from spandrel import ModelLoader
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict
import psutil
import time
import traceback
# --- Configuration ---
class Config:
    """Configuration settings for the application."""

    # Directory (relative to the working directory) that holds downloaded weights.
    MODEL_DIR = "weights"

    # RealESRGAN x2 weights: download location and local file name.
    REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
    REALESRGAN_FILENAME = "RealESRGAN_x2plus.pth"

    # Hugging Face Hub model ids passed to ``from_pretrained``.
    SWIN2SR_ID = "caidas/swin2SR-classical-sr-x2-64"
    SD_ID = "stabilityai/stable-diffusion-x4-upscaler"

    # SOTA Models (2025)
    SPAN_URL = "https://huggingface.co/Phips/2xNomosUni_span_multijpg/resolve/main/2xNomosUni_span_multijpg.safetensors"
    SPAN_FILENAME = "2xNomosUni_span_multijpg.safetensors"
    HATS_URL = "https://huggingface.co/Phips/4xNomos8kSCHAT-S/resolve/main/4xNomos8kSCHAT-S.safetensors"
    HATS_FILENAME = "4xNomos8kSCHAT-S.safetensors"

    MAX_IMAGE_SIZE_SD = 512  # Max dimension for SD input to prevent OOM
    DEVICE = "cpu"  # Force CPU for this demo, can be "cuda" if available

    @staticmethod
    def ensure_model_dir():
        """Create ``Config.MODEL_DIR`` (and missing parents) if it does not exist.

        ``exist_ok=True`` makes the call idempotent and removes the window
        between an existence check and the creation in which another
        process could create the directory first.
        """
        os.makedirs(Config.MODEL_DIR, exist_ok=True)
# --- Logging Setup ---
class LogCapture(io.StringIO):
    """In-memory text buffer used as the target stream for captured log records.

    Adds nothing to ``io.StringIO``; the subclass only gives the buffer a
    descriptive type name.
    """
# Buffer that accumulates every formatted log line for display in the UI.
log_capture_string = LogCapture()

# Handler writing INFO-and-above records into the buffer as "time - level - message".
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler(log_capture_string)
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)

# Application logger; everything it emits at INFO or above reaches the buffer.
logger = logging.getLogger("UpscalerApp")
logger.addHandler(ch)
logger.setLevel(logging.INFO)
def get_logs() -> str:
    """Return the full text accumulated in the log capture buffer so far."""
    captured = log_capture_string.getvalue()
    return captured
# --- System Monitoring ---
def get_system_usage() -> str:
    """Return a one-line summary of current CPU and RAM usage.

    The memory statistics are sampled once so that the percentage and the
    used-GB figure in the returned string describe the same snapshot.
    """
    cpu_percent = psutil.cpu_percent()
    memory = psutil.virtual_memory()
    ram_used_gb = memory.used / (1024 ** 3)
    return f"CPU: {cpu_percent}% | RAM: {memory.percent}% ({ram_used_gb:.1f} GB used)"
# --- Abstract Base Class for Models ---
class UpscalerStrategy(ABC):
    """Common interface shared by every upscaling backend.

    Subclasses provide ``load`` and ``upscale``; ``unload`` is implemented
    here and releases whatever object is stored in ``self.model``.
    """

    def __init__(self):
        # Human-readable label used in log messages; subclasses overwrite it.
        self.name = "Unknown"
        # Loaded model object, or None while nothing is loaded.
        self.model = None

    @abstractmethod
    def load(self) -> None:
        """Load the model into memory."""

    @abstractmethod
    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale the given image."""

    def unload(self) -> None:
        """Drop the loaded model, run the garbage collector and log the event.

        Does nothing (and logs nothing) when no model is loaded.
        """
        if self.model is None:
            return
        del self.model
        self.model = None
        gc.collect()
        logger.info(f"Unloaded {self.name}")
# --- Helper Functions for Optimization ---
def manual_tile_upscale(model, img_tensor, tile_size=256, tile_pad=10, scale=2):
    """
    Low-level tiling implementation for custom models.
    Prevents OOM by processing image in chunks.

    The image is cut into ``tile_size`` x ``tile_size`` tiles. Each tile is
    extended by up to ``tile_pad`` pixels of surrounding context (clipped at
    the image border), passed through ``model``, and the part of the result
    that corresponds to the un-padded tile is written into the output.

    Args:
        model: Callable mapping a 4-D tensor to a 4-D tensor whose last two
            dimensions are ``scale`` times larger.
        img_tensor: 4-D input tensor; the last two dimensions are treated
            as height and width.
        tile_size: Edge length of a tile in input pixels (must be > 0).
        tile_pad: Context border added around each tile (must be >= 0).
        scale: Integer upscaling factor of ``model`` (must be > 0).

    Returns:
        Tensor of shape ``(B, C, H * scale, W * scale)`` on the device and
        with the dtype of ``img_tensor``.

    Raises:
        ValueError: If an argument is out of range, or if ``model`` returns
            a tile whose spatial size is not ``scale`` times the input tile.
            Without this check a model with a larger factor than ``scale``
            would be cropped into a silently wrong image.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if tile_pad < 0:
        raise ValueError(f"tile_pad must be non-negative, got {tile_pad}")
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    B, C, H, W = img_tensor.shape

    # Number of tiles along each axis (ceil division).
    n_rows = (H + tile_size - 1) // tile_size
    n_cols = (W + tile_size - 1) // tile_size

    output = torch.zeros(B, C, H * scale, W * scale,
                         device=img_tensor.device, dtype=img_tensor.dtype)

    # One no_grad context for the whole loop instead of one per tile.
    with torch.no_grad():
        for row in range(n_rows):
            # Un-padded and padded row extent of this tile in input pixels.
            top = row * tile_size
            bottom = min(top + tile_size, H)
            top_pad = max(0, top - tile_pad)
            bottom_pad = min(H, bottom + tile_pad)

            for col in range(n_cols):
                left = col * tile_size
                right = min(left + tile_size, W)
                left_pad = max(0, left - tile_pad)
                right_pad = min(W, right + tile_pad)

                tile = img_tensor[:, :, top_pad:bottom_pad, left_pad:right_pad]
                tile_out = model(tile)

                expected_hw = ((bottom_pad - top_pad) * scale,
                               (right_pad - left_pad) * scale)
                actual_hw = tuple(tile_out.shape[-2:])
                if actual_hw != expected_hw:
                    raise ValueError(
                        f"Model returned a tile of size {actual_hw}, expected "
                        f"{expected_hw} for scale={scale}"
                    )

                # Offset of the real tile inside the padded result.
                crop_top = (top - top_pad) * scale
                crop_left = (left - left_pad) * scale
                crop_bottom = crop_top + (bottom - top) * scale
                crop_right = crop_left + (right - left) * scale

                output[:, :, top * scale:bottom * scale, left * scale:right * scale] = \
                    tile_out[:, :, crop_top:crop_bottom, crop_left:crop_right]

    return output
def select_tile_config(height, width):
    """
    Pick tile size and halo padding from the image area.

    The area is measured in units of 1024 * 1024 pixels; larger images get
    smaller tiles with a wider padding border. Returns a new dict with the
    keys ``'tile'`` and ``'tile_pad'``.
    """
    # (exclusive area ceiling, tile, tile_pad), checked in ascending order.
    tiers = (
        (2, 512, 10),   # < 1080p
        (6, 384, 15),   # < 4K
        (16, 256, 20),  # < 8K
    )
    megapixels = (height * width) / (1024 ** 2)
    for ceiling, tile, tile_pad in tiers:
        if megapixels < ceiling:
            return {'tile': tile, 'tile_pad': tile_pad}
    # 8K+
    return {'tile': 128, 'tile_pad': 25}
# --- Concrete Implementations ---
class RealESRGANStrategy(UpscalerStrategy):
    """RealESRGAN x2 backend.

    Weights are downloaded on first use into ``Config.MODEL_DIR``, loaded
    through spandrel's ``ModelLoader`` and applied tile-by-tile under
    autocast.
    """

    def __init__(self):
        super().__init__()
        self.name = "RealESRGAN x2"
        # True once torch.compile has been attempted (whether or not it worked).
        # NOTE: the inherited unload() does not reset this flag, so a model
        # that is reloaded after unloading is not compiled again.
        self.compiled = False

    def _download_weights(self, model_path: str) -> None:
        """Download the weight file to ``model_path``.

        The data is written to ``<model_path>.part`` and renamed only after
        the transfer finished, so an interrupted download never leaves a
        truncated file that ``load`` would later treat as already present.
        """
        logger.info(f"Downloading {Config.REALESRGAN_FILENAME}...")
        tmp_path = model_path + ".part"
        try:
            # (connect, read) timeouts keep a stalled connection from hanging forever.
            with requests.get(Config.REALESRGAN_URL, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, model_path)
            logger.info("Download complete.")
        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def _try_compile(self) -> None:
        """Wrap the model with ``torch.compile`` where that is expected to pay off.

        Failures are logged and otherwise ignored; the attempt is recorded in
        ``self.compiled`` either way.
        """
        try:
            # 'reduce-overhead' uses CUDA graphs, so only use it on CUDA
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
                logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                # Windows requires MSVC for Inductor (default cpu backend);
                # skip to avoid the "Compiler: cl is not found" error.
                logger.info("ℹ Skipping torch.compile on Windows CPU to avoid MSVC requirement.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                # Skip compilation on weak CPUs (e.g. HF Spaces Free Tier) to avoid long startup times
                logger.info("ℹ Skipping torch.compile on low-core CPU to prevent timeout.")
            else:
                self.model = torch.compile(self.model)
                logger.info("✓ torch.compile enabled (default mode)")
        except Exception as e:
            logger.warning(f"⚠ torch.compile not available or failed: {e}")
        finally:
            self.compiled = True  # Mark as tried

    def load(self) -> None:
        """Download the weights if missing and load the model onto ``Config.DEVICE``.

        Does nothing when a model is already loaded. Download and loading
        errors are logged and re-raised; after a loading error ``self.model``
        is reset to ``None`` so that a later call retries from scratch.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.REALESRGAN_FILENAME)
        if not os.path.exists(model_path):
            self._download_weights(model_path)
        try:
            self.model = ModelLoader().load_from_file(model_path)
            self.model.eval()
            self.model.to(Config.DEVICE)
            # Optimization: torch.compile
            if not self.compiled:
                self._try_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            self.model = None
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled by the model (x2).

        The input is converted to RGB first so grayscale, palette and RGBA
        images produce the 3-channel HxWx3 array the tensor conversion below
        expects. Extra keyword arguments are accepted and ignored.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        # Optimization: Dynamic Tiling
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")

        # Optimization: Mixed Precision (AMP) - float16 on CUDA, bfloat16 otherwise.
        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad also covers the fallback path, which would otherwise record
        # an autograd graph for the whole image.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=2
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back to standard FP32: {e}")
                # Fallback to standard execution
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")

        # Benchmark info; guard against a zero-length measurement.
        output_megapixels = (output_np.shape[0] * output_np.shape[1]) / (1024 ** 2)
        throughput = output_megapixels / elapsed if elapsed > 0 else float("inf")
        logger.info(f"Speed: {throughput:.2f} MP/s")
        return Image.fromarray(output_np)
class Swin2SRStrategy(UpscalerStrategy):
    """Swin2SR x2 backend loaded from the Hugging Face Hub via transformers."""

    def __init__(self):
        super().__init__()
        self.name = "Swin2SR x2"
        # Image processor matching the model; loaded together with it.
        self.processor = None

    def load(self) -> None:
        """Load processor and model onto ``Config.DEVICE``.

        Does nothing when both are already present. Errors are logged and
        re-raised.
        """
        if self.model is not None and self.processor is not None:
            return
        logger.info(f"Loading {self.name}...")
        try:
            self.processor = AutoImageProcessor.from_pretrained(Config.SWIN2SR_ID)
            model = Swin2SRForImageSuperResolution.from_pretrained(Config.SWIN2SR_ID)
            self.model = model.to(Config.DEVICE)  # type: ignore
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load Swin2SR: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled by the model.

        Extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: If the processor is still missing after loading.
        """
        if self.model is None or self.processor is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()
        if self.processor is None:
            raise ValueError("Processor not loaded")

        # Normalise grayscale / palette / RGBA inputs to three channels.
        rgb = image.convert("RGB")
        inputs = self.processor(images=rgb, return_tensors="pt").to(Config.DEVICE)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # drop a height or width of 1.
        output = outputs.reconstruction.data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        output = np.moveaxis(output, source=0, destination=-1)

        # The image processor may pad the input at the bottom/right before
        # inference (see the Swin2SR image processor docs); crop so the result
        # is exactly `scale` times the original size.
        scale = getattr(self.model.config, "upscale", 2)
        output = output[: rgb.height * scale, : rgb.width * scale]

        output = (output * 255.0).round().astype(np.uint8)
        logger.info(f"Inference finished in {time.time() - start_time:.2f}s")
        return Image.fromarray(output)
class StableDiffusionStrategy(UpscalerStrategy):
    """Stable Diffusion x4 upscaler pipeline (prompt-guided)."""

    def __init__(self):
        super().__init__()
        self.name = "Stable Diffusion x4"

    def load(self) -> None:
        """Load the diffusers pipeline and enable its memory-saving options.

        Does nothing when a pipeline is already loaded. ``self.model`` is
        assigned only after the pipeline is fully configured, so a failure
        part-way through never leaves a half-configured pipeline behind.
        Errors are logged and re-raised.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name} (this may take time)...")
        try:
            pipe = StableDiffusionUpscalePipeline.from_pretrained(
                Config.SD_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            )
            # Optimizations for CPU
            pipe.enable_attention_slicing("max")
            pipe.enable_vae_tiling()
            # Same device as the other strategies (a no-op for the default "cpu").
            self.model = pipe.to(Config.DEVICE)
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load Stable Diffusion: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled by the pipeline.

        Keyword Args:
            prompt: Text prompt guiding the generation; defaults to
                ``"high quality, detailed"``.

        Inputs whose longer side exceeds ``Config.MAX_IMAGE_SIZE_SD`` are
        shrunk first (aspect ratio preserved, each side at least 1 pixel).
        """
        if self.model is None:
            self.load()
        prompt = kwargs.get("prompt", "high quality, detailed")

        image = image.convert("RGB")

        # Pre-check size
        if max(image.size) > Config.MAX_IMAGE_SIZE_SD:
            ratio = Config.MAX_IMAGE_SIZE_SD / max(image.size)
            # max(1, ...) keeps very elongated images from collapsing to a
            # zero-pixel side, which resize() rejects.
            new_size = (
                max(1, int(image.size[0] * ratio)),
                max(1, int(image.size[1] * ratio)),
            )
            image = image.resize(new_size, Image.Resampling.LANCZOS)
            logger.warning(f"Resized input to {new_size} to prevent OOM on CPU.")

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # Dedicated generator: fixed seed for reproducible output without
        # reseeding torch's global RNG as torch.manual_seed() would.
        generator = torch.Generator(device="cpu").manual_seed(42)
        output = self.model(
            prompt=prompt,
            image=image,
            num_inference_steps=20,
            guidance_scale=7.0,
            generator=generator
        ).images[0]  # type: ignore
        logger.info(f"Inference finished in {time.time() - start_time:.2f}s")
        return output
class SpanStrategy(UpscalerStrategy):
    """SPAN (NomosUni) x2 backend.

    Weights are downloaded on first use into ``Config.MODEL_DIR``, loaded
    through spandrel's ``ModelLoader`` and applied tile-by-tile under
    autocast.
    """

    def __init__(self):
        super().__init__()
        self.name = "SPAN (NomosUni) x2"
        # True once torch.compile has been attempted (whether or not it worked).
        # NOTE: the inherited unload() does not reset this flag, so a model
        # that is reloaded after unloading is not compiled again.
        self.compiled = False

    def _download_weights(self, model_path: str) -> None:
        """Download the weight file to ``model_path``.

        The data is written to ``<model_path>.part`` and renamed only after
        the transfer finished, so an interrupted download never leaves a
        truncated file that ``load`` would later treat as already present.
        """
        logger.info(f"Downloading {Config.SPAN_FILENAME}...")
        tmp_path = model_path + ".part"
        try:
            # (connect, read) timeouts keep a stalled connection from hanging forever.
            with requests.get(Config.SPAN_URL, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, model_path)
            logger.info("Download complete.")
        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def _try_compile(self) -> None:
        """Wrap the model with ``torch.compile`` where that is expected to pay off.

        Failures are logged and otherwise ignored; the attempt is recorded in
        ``self.compiled`` either way.
        """
        try:
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
                logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on Windows CPU.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on low-core CPU.")
            else:
                self.model = torch.compile(self.model)
                logger.info("✓ torch.compile enabled (default mode)")
        except Exception as e:
            logger.warning(f"⚠ torch.compile failed: {e}")
        finally:
            self.compiled = True

    def load(self) -> None:
        """Download the weights if missing and load the model onto ``Config.DEVICE``.

        Does nothing when a model is already loaded. Download and loading
        errors are logged and re-raised; after a loading error ``self.model``
        is reset to ``None`` so that a later call retries from scratch.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.SPAN_FILENAME)
        if not os.path.exists(model_path):
            self._download_weights(model_path)
        try:
            self.model = ModelLoader().load_from_file(model_path)
            self.model.eval()
            self.model.to(Config.DEVICE)
            # Optimization: torch.compile
            if not self.compiled:
                self._try_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            self.model = None
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled by the model (x2).

        The input is converted to RGB first so grayscale, palette and RGBA
        images produce the 3-channel HxWx3 array the tensor conversion below
        expects. Extra keyword arguments are accepted and ignored.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        # SPAN is very efficient, but we still use tiling for safety on huge images
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad also covers the fallback path, which would otherwise record
        # an autograd graph for the whole image.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=2
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back: {e}")
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)
        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return Image.fromarray(output_np)
class HatsStrategy(UpscalerStrategy):
    """HAT-S x4 backend.

    Weights are downloaded on first use into ``Config.MODEL_DIR``, loaded
    through spandrel's ``ModelLoader`` and applied tile-by-tile under
    autocast.
    """

    def __init__(self):
        super().__init__()
        self.name = "HAT-S x4"
        # True once torch.compile has been attempted (whether or not it worked).
        # NOTE: the inherited unload() does not reset this flag, so a model
        # that is reloaded after unloading is not compiled again.
        self.compiled = False

    def _download_weights(self, model_path: str) -> None:
        """Download the weight file to ``model_path``.

        The data is written to ``<model_path>.part`` and renamed only after
        the transfer finished, so an interrupted download never leaves a
        truncated file that ``load`` would later treat as already present.
        """
        logger.info(f"Downloading {Config.HATS_FILENAME}...")
        tmp_path = model_path + ".part"
        try:
            # (connect, read) timeouts keep a stalled connection from hanging forever.
            with requests.get(Config.HATS_URL, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, model_path)
            logger.info("Download complete.")
        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def _try_compile(self) -> None:
        """Wrap the model with ``torch.compile`` where that is expected to pay off.

        Compilation is skipped on Windows CPU and on CPUs with fewer than
        four physical cores. A failure is logged as a warning instead of
        being swallowed silently; the attempt is recorded in
        ``self.compiled`` either way.
        """
        try:
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                pass
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                pass
            else:
                self.model = torch.compile(self.model)
        except Exception as e:
            logger.warning(f"⚠ torch.compile failed: {e}")
        finally:
            self.compiled = True

    def load(self) -> None:
        """Download the weights if missing and load the model onto ``Config.DEVICE``.

        Does nothing when a model is already loaded. Download and loading
        errors are logged and re-raised; after a loading error ``self.model``
        is reset to ``None`` so that a later call retries from scratch.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.HATS_FILENAME)
        if not os.path.exists(model_path):
            self._download_weights(model_path)
        try:
            self.model = ModelLoader().load_from_file(model_path)
            self.model.eval()
            self.model.to(Config.DEVICE)
            if not self.compiled:
                self._try_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            self.model = None
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled by the model (x4).

        The input is converted to RGB first so grayscale, palette and RGBA
        images produce the 3-channel HxWx3 array the tensor conversion below
        expects. Extra keyword arguments are accepted and ignored.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad also covers the fallback path, which would otherwise record
        # an autograd graph for the whole image.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=4  # HAT-S is x4
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back: {e}")
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)
        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return Image.fromarray(output_np)
# --- Model Manager (Singleton-ish) ---
class UpscalerManager:
    """Manages model lifecycle and selection.

    Holds one strategy instance per model name and keeps at most one of them
    loaded by unloading the previously selected strategy on every switch.
    """

    def __init__(self):
        # Display name -> strategy; the key order is the order shown in the UI.
        self.strategies: Dict[str, UpscalerStrategy] = {
            "SPAN (NomosUni) x2": SpanStrategy(),
            "RealESRGAN x2": RealESRGANStrategy(),
            "HAT-S x4": HatsStrategy(),
            "Swin2SR x2": Swin2SRStrategy(),
            "Stable Diffusion x4": StableDiffusionStrategy()
        }
        # Name of the most recently selected strategy, or None.
        self.current_model_name: Optional[str] = None

    def get_strategy(self, name: str) -> UpscalerStrategy:
        """Return the strategy registered under ``name``.

        When ``name`` differs from the current selection, the previously
        selected strategy is unloaded first. The returned strategy is not
        loaded here.

        Raises:
            ValueError: If ``name`` is not a registered model.
        """
        if name not in self.strategies:
            raise ValueError(f"Model {name} not found.")
        # Memory Optimization for Free Tier (16GB RAM limit):
        # Ensure only one model is loaded at a time.
        if self.current_model_name != name:
            if self.current_model_name is not None:
                logger.info(f"Switching models: Unloading {self.current_model_name}...")
                self.strategies[self.current_model_name].unload()
            self.current_model_name = name
        return self.strategies[name]

    def unload_all(self):
        """Unload all models to free memory."""
        for strategy in self.strategies.values():
            strategy.unload()
        # Nothing is selected any more; without this reset the next
        # get_strategy() call would log a "Switching models: Unloading ..."
        # message for a model that is already gone.
        self.current_model_name = None
        gc.collect()
        logger.info("All models unloaded.")
# Single module-level manager shared by the UI callbacks below.
manager = UpscalerManager()
# --- Gradio Interface Logic ---
def process_image(input_img: Image.Image, model_name: str, prompt: str) -> Tuple[Optional[Image.Image], str, str]:
    """Upscale ``input_img`` with the selected model for the Gradio UI.

    Args:
        input_img: Image to upscale, or ``None`` when nothing was provided.
        model_name: Key of the strategy to use.
        prompt: Text prompt forwarded to the strategy as ``prompt=``.

    Returns:
        ``(output image or None, captured logs, system usage string)``.
        Exceptions are not propagated: they are logged together with their
        traceback and the image slot of the result is ``None``.
    """
    if input_img is None:
        return None, get_logs(), get_system_usage()
    try:
        strategy = manager.get_strategy(model_name)
        output = strategy.upscale(input_img, prompt=prompt)
        # Explicit GC after heavy operations
        gc.collect()
        return output, get_logs(), get_system_usage()
    except Exception as e:
        error_msg = f"Critical Error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        # The logger call above already put the message into the captured
        # logs, so returning get_logs() shows it to the user exactly once.
        return None, get_logs(), get_system_usage()
def unload_models():
    """Unload every model and return ``(captured logs, system usage string)``."""
    manager.unload_all()
    logs = get_logs()
    usage = get_system_usage()
    return logs, usage
# --- UI Construction ---
# Markdown rendered at the top of the page.
desc = """
### 🚀 Enterprise-Grade Universal Upscaler (SOTA 2025)
Select a specialized model to upscale your image.
* **SPAN (NomosUni) x2**: ⚡ **SOTA Speed**. Fastest CPU model. Best for general use.
* **RealESRGAN x2**: 🛡️ **Robust**. Best for removing JPEG artifacts and noise.
* **HAT-S x4**: 💎 **SOTA Quality**. Best texture detail (slower).
* **Swin2SR x2**: 🎯 High fidelity, removes compression artifacts.
* **Stable Diffusion x4**: 🎨 Generative upscaling. Adds missing details (slow, high RAM).
"""

with gr.Blocks(title="Universal Upscaler Pro") as iface:
    gr.Markdown(desc)

    with gr.Row():
        # Left column: input image, model choice, prompt and actions.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")

            with gr.Group():
                model_selector = gr.Dropdown(
                    choices=list(manager.strategies),
                    value="SPAN (NomosUni) x2",
                    label="Select Model Architecture",
                )
                prompt_input = gr.Textbox(
                    label="Prompt (Stable Diffusion Only)",
                    value="highly detailed, 4k, sharp",
                    placeholder="Describe the image content...",
                )

            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Memory Management")
                unload_btn = gr.Button("Unload All Models (Free RAM)", variant="secondary")

            submit_btn = gr.Button("✨ Upscale Image", variant="primary", size="lg")
            system_info = gr.Label(value=get_system_usage(), label="System Status")

        # Right column: result image and captured logs.
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Upscaled Result")
            logs_output = gr.TextArea(label="Execution Logs", interactive=False, lines=10)

    # Event wiring: each callback's return tuple maps onto `outputs` in order.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, prompt_input],
        outputs=[output_image, logs_output, system_info],
    )
    unload_btn.click(
        fn=unload_models,
        inputs=[],
        outputs=[logs_output, system_info],
    )

    # Auto-refresh system info every 2 seconds (optional, can be heavy on UI)
    # iface.load(get_system_usage, None, system_info, every=2)

iface.launch()