# Author: jgitsolutions
# Commit message: Update app.py
# Commit: c0f4292 (verified)
import gradio as gr
import torch
import gc
from PIL import Image
import numpy as np
import logging
import io
import os
import requests
from spandrel import ModelLoader
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict
import psutil
import time
import traceback
# --- Configuration ---
class Config:
    """Configuration settings for the application."""

    # Directory where downloaded model checkpoints are stored ("." = working directory).
    MODEL_DIR = "."

    REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
    REALESRGAN_FILENAME = "RealESRGAN_x2plus.pth"

    # SOTA Models (2025)
    SPAN_URL = "https://huggingface.co/Phips/2xNomosUni_span_multijpg/resolve/main/2xNomosUni_span_multijpg.safetensors"
    SPAN_FILENAME = "2xNomosUni_span_multijpg.safetensors"
    HATS_URL = "https://huggingface.co/Phips/4xNomos8kSCHAT-S/resolve/main/4xNomos8kSCHAT-S.safetensors"
    HATS_FILENAME = "4xNomos8kSCHAT-S.safetensors"

    DEVICE = "cpu"  # Force CPU for this demo, can be "cuda" if available

    @staticmethod
    def ensure_model_dir():
        """Create ``Config.MODEL_DIR`` (including missing parents) if it does not exist."""
        # exist_ok=True replaces the former exists()/makedirs() pair, which could
        # raise FileExistsError if the directory appeared between the two calls.
        os.makedirs(Config.MODEL_DIR, exist_ok=True)
# --- Logging Setup ---
class LogCapture(io.StringIO):
    """In-memory text buffer that receives the application's formatted log records.

    It behaves exactly like :class:`io.StringIO`; the subclass only gives the
    log sink its own descriptive type.
    """
# Application logger; INFO and above are recorded.
logger = logging.getLogger("UpscalerApp")
logger.setLevel(logging.INFO)

# Buffer read back by get_logs() for the UI log panel.
# NOTE(review): nothing ever trims this buffer, so it grows for the life of the process.
log_capture_string = LogCapture()

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler(log_capture_string)
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)

logger.addHandler(ch)
def get_logs() -> str:
    """Return the full text written to the in-memory log buffer so far."""
    captured_text = log_capture_string.getvalue()
    return captured_text
# --- System Monitoring ---
def get_system_usage() -> str:
    """
    Return a one-line CPU / RAM usage summary for the UI status label.

    Format: ``"CPU: <pct>% | RAM: <pct>% (<used> GB used)"``. According to the
    psutil docs, ``cpu_percent()`` without an interval reports usage since the
    previous call.
    """
    cpu_percent = psutil.cpu_percent()
    # Take one memory snapshot so the percentage and the used bytes agree
    # (two separate calls could sample different moments).
    memory = psutil.virtual_memory()
    ram_used_gb = memory.used / (1024 ** 3)
    return f"CPU: {cpu_percent}% | RAM: {memory.percent}% ({ram_used_gb:.1f} GB used)"
# --- Abstract Base Class for Models ---
class UpscalerStrategy(ABC):
    """
    Abstract base class for upscaling strategies.

    Subclasses implement ``load`` (store the ready model in ``self.model``) and
    ``upscale``. ``unload`` is shared and drops the model again.
    """

    def __init__(self):
        # Loaded model object, or None while nothing is loaded.
        self.model = None
        # Human-readable name used in log messages.
        self.name = "Unknown"

    @abstractmethod
    def load(self) -> None:
        """Load the model into memory."""

    @abstractmethod
    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale the given image and return the result."""

    def unload(self) -> None:
        """
        Unload the model to free memory.

        Does nothing if no model is loaded. Otherwise it drops the reference,
        runs the garbage collector and, when CUDA is available, releases
        PyTorch's cached GPU memory (relevant if ``Config.DEVICE`` is "cuda").
        """
        if self.model is None:
            return
        self.model = None  # rebinding drops the only reference held here
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info(f"Unloaded {self.name}")
# --- Helper Functions for Optimization ---
def manual_tile_upscale(model, img_tensor, tile_size=256, tile_pad=10, scale=2):
    """
    Upscale ``img_tensor`` by running ``model`` over overlapping tiles.

    Processing the image in chunks bounds peak memory (prevents OOM on large
    inputs). Each tile is extended by up to ``tile_pad`` pixels of neighbouring
    context (clamped at the image border). The upscaled context is cropped away
    again, so tile seams do not show.

    Args:
        model: Callable expected to map a ``(B, C, h, w)`` tensor to
            ``(B, C, h*scale, w*scale)``.
        img_tensor: Input tensor unpacked as ``(B, C, H, W)``.
        tile_size: Edge length of each input tile before padding. Values <= 0
            disable tiling (the whole image is one tile).
        tile_pad: Context pixels added on each side of a tile.
        scale: Upscaling factor the model is expected to apply.

    Returns:
        Tensor of shape ``(B, C, H*scale, W*scale)`` with the dtype and device of
        ``img_tensor``.

    Raises:
        ValueError: if a tile's output size does not match ``scale``. Without this
            check, a model with a larger factor would silently yield a scrambled image.
    """
    batch, channels, height, width = img_tensor.shape
    if tile_size <= 0:
        # One tile covering everything; max(..., 1) keeps the division below valid.
        tile_size = max(height, width, 1)

    tiles_down = (height + tile_size - 1) // tile_size
    tiles_across = (width + tile_size - 1) // tile_size

    output = torch.zeros(batch, channels, height * scale, width * scale,
                         device=img_tensor.device, dtype=img_tensor.dtype)

    for row in range(tiles_down):
        # Row extent of the tile, then widened by the halo.
        top = row * tile_size
        bottom = min(top + tile_size, height)
        top_pad = max(0, top - tile_pad)
        bottom_pad = min(height, bottom + tile_pad)
        for col in range(tiles_across):
            left = col * tile_size
            right = min(left + tile_size, width)
            left_pad = max(0, left - tile_pad)
            right_pad = min(width, right + tile_pad)

            tile = img_tensor[:, :, top_pad:bottom_pad, left_pad:right_pad]
            with torch.no_grad():
                tile_out = model(tile)

            expected = ((bottom_pad - top_pad) * scale, (right_pad - left_pad) * scale)
            actual = tuple(tile_out.shape[-2:])
            if actual != expected:
                raise ValueError(
                    f"Tile output size {actual} does not match scale={scale} "
                    f"(expected {expected} for a {bottom_pad - top_pad}x{right_pad - left_pad} tile)"
                )

            # Offset of the un-padded region inside the upscaled tile (halo removal).
            crop_top = (top - top_pad) * scale
            crop_left = (left - left_pad) * scale
            output[:, :, top * scale:bottom * scale, left * scale:right * scale] = tile_out[
                :, :,
                crop_top:crop_top + (bottom - top) * scale,
                crop_left:crop_left + (right - left) * scale,
            ]
    return output
def select_tile_config(height, width):
    """
    Choose tiling parameters for ``manual_tile_upscale`` from the image area.

    Bigger images get smaller tiles and a wider context halo.

    Args:
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        ``{'tile': <tile edge length>, 'tile_pad': <halo pixels>}``.
    """
    area_mp = height * width / 1024 ** 2  # area in units of 2**20 pixels
    # (exclusive upper bound, tile, tile_pad), checked in order.
    tiers = (
        (2, 512, 10),   # up to ~1080p (1920x1080 is about 1.98)
        (6, 384, 15),
        (16, 256, 20),  # includes 4K UHD (3840x2160 is about 7.9)
    )
    for upper_bound, tile, pad in tiers:
        if area_mp < upper_bound:
            return {'tile': tile, 'tile_pad': pad}
    return {'tile': 128, 'tile_pad': 25}  # very large inputs (e.g. 8K)
# --- Concrete Implementations ---
class RealESRGANStrategy(UpscalerStrategy):
    """Real-ESRGAN x2plus upscaler loaded through spandrel's ModelLoader."""

    # Upscaling factor of RealESRGAN_x2plus; passed to the tiler.
    SCALE = 2

    def __init__(self):
        super().__init__()
        self.name = "RealESRGAN x2"
        # True once the torch.compile decision was made for the currently loaded
        # model; reset by unload() because every load creates a new model object.
        self.compiled = False

    @staticmethod
    def _download(url: str, dest: str) -> None:
        """
        Stream ``url`` into ``dest`` without ever leaving a partial file at ``dest``.

        Bytes go to ``dest + ".part"`` and are renamed onto ``dest`` only after the
        transfer finished. ``load`` treats an existing ``dest`` as a complete
        download, so an interrupted direct write would break loading permanently.

        Raises:
            requests.RequestException: on HTTP errors or timeouts.
            OSError: if the file cannot be written or renamed.
        """
        part_path = dest + ".part"
        try:
            # (connect, read) timeouts so a stalled server cannot hang the request.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(part_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(part_path, dest)
        except BaseException:
            try:
                os.remove(part_path)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise

    def _maybe_compile(self) -> None:
        """
        Wrap ``self.model`` with ``torch.compile`` where it is expected to help.

        Skipped on Windows CPU (Inductor needs MSVC) and on CPUs with fewer than
        4 physical cores (long startup). Errors raised by torch.compile are
        logged and the uncompiled model is kept.
        """
        try:
            # 'reduce-overhead' uses CUDA graphs, so only use it on CUDA
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
                logger.info("[INFO] torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                # Windows requires MSVC for Inductor (default cpu backend)
                # We skip it to avoid "Compiler: cl is not found" error unless user has it.
                logger.info("[INFO] Skipping torch.compile on Windows CPU to avoid MSVC requirement.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                # Skip compilation on weak CPUs (e.g. HF Spaces Free Tier) to avoid long startup times
                logger.info("[INFO] Skipping torch.compile on low-core CPU to prevent timeout.")
            else:
                # On Linux/Mac CPU, use default mode or skip if problematic. Default is usually safe.
                self.model = torch.compile(self.model)
                logger.info("[SUCCESS] torch.compile enabled (default mode)")
        except Exception as e:
            logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
        self.compiled = True  # Mark as tried

    def load(self) -> None:
        """
        Make the model ready for inference; does nothing if already loaded.

        Downloads the checkpoint into ``Config.MODEL_DIR`` when missing, loads it
        with spandrel, sets eval mode, moves it to ``Config.DEVICE`` and may
        apply torch.compile.

        Raises:
            Exception: download and load errors are logged and re-raised.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.REALESRGAN_FILENAME)
        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.REALESRGAN_FILENAME}...")
            try:
                self._download(Config.REALESRGAN_URL, model_path)
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise
            logger.info("Download complete.")
        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
            # Publish only a fully prepared model. If anything above fails,
            # self.model stays None and the next call retries the load.
            self.model = model
            if not self.compiled:
                self._maybe_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def unload(self) -> None:
        """Drop the model and reset ``compiled`` so the next load can compile again."""
        super().unload()
        self.compiled = False

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """
        Return a 2x upscaled copy of ``image`` as an RGB PIL image.

        Non-RGB inputs (RGBA, L, P, ...) are converted to RGB first because the
        network takes 3-channel input; alpha is discarded. Inference is tiled
        (see ``select_tile_config``). If tiling fails, the whole image is run
        in one pass instead.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        if image.mode != "RGB":
            image = image.convert("RGB")
        img_np = np.array(image).astype(np.float32) / 255.0
        # HWC uint8-range floats -> 1xCxHxW in [0, 1]
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        # Optimization: Dynamic Tiling
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")

        # Optimization: Mixed Precision (AMP), only on non-CPU devices (see below).
        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
        # no_grad also covers the fallback path. Without it, the fallback would
        # record an autograd graph for the full image, and .numpy() below would
        # refuse an output that requires grad.
        with torch.no_grad():
            try:
                # Explicitly disable autocast on CPU for RealESRGAN to avoid "PythonFallbackKernel" errors
                context = (
                    torch.autocast(device_type=Config.DEVICE, dtype=dtype)
                    if Config.DEVICE != 'cpu'
                    else torch.no_grad()
                )
                with context:
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=self.SCALE,
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back to standard FP32: {e}")
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        # Benchmark info; guarded because a clock tick can report zero elapsed time.
        output_megapixels = (output_np.shape[0] * output_np.shape[1]) / (1024 ** 2)
        if elapsed > 0:
            throughput = output_megapixels / elapsed
            logger.info(f"Speed: {throughput:.2f} MP/s")
        return Image.fromarray(output_np)
class SpanStrategy(UpscalerStrategy):
    """SPAN (2xNomosUni_span_multijpg) upscaler loaded through spandrel's ModelLoader."""

    # Upscaling factor of the NomosUni SPAN checkpoint; passed to the tiler.
    SCALE = 2

    def __init__(self):
        super().__init__()
        self.name = "SPAN (NomosUni) x2"
        # True once the torch.compile decision was made for the currently loaded
        # model; reset by unload() because every load creates a new model object.
        self.compiled = False

    @staticmethod
    def _download(url: str, dest: str) -> None:
        """
        Stream ``url`` into ``dest`` without ever leaving a partial file at ``dest``.

        Bytes go to ``dest + ".part"`` and are renamed onto ``dest`` only after the
        transfer finished. ``load`` treats an existing ``dest`` as a complete
        download, so an interrupted direct write would break loading permanently.

        Raises:
            requests.RequestException: on HTTP errors or timeouts.
            OSError: if the file cannot be written or renamed.
        """
        part_path = dest + ".part"
        try:
            # (connect, read) timeouts so a stalled server cannot hang the request.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(part_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(part_path, dest)
        except BaseException:
            try:
                os.remove(part_path)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise

    def _maybe_compile(self) -> None:
        """
        Apply ``torch.compile`` only on CUDA; every CPU configuration skips it.

        Errors raised by torch.compile are logged and the uncompiled model is kept.
        """
        try:
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
                logger.info("[INFO] torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("[INFO] Skipping torch.compile on Windows CPU.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("[INFO] Skipping torch.compile on low-core CPU.")
            else:
                # SPAN architecture uses .data.clone() in forward pass which breaks torch.compile/inductor
                logger.info("[INFO] Skipping torch.compile for SPAN (incompatible architecture).")
        except Exception as e:
            # Compilation is optional: log the reason and continue uncompiled.
            logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
        self.compiled = True  # Mark as tried

    def load(self) -> None:
        """
        Make the model ready for inference; does nothing if already loaded.

        Downloads the checkpoint into ``Config.MODEL_DIR`` when missing, loads it
        with spandrel, sets eval mode and moves it to ``Config.DEVICE``.

        Raises:
            Exception: download and load errors are logged and re-raised.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.SPAN_FILENAME)
        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.SPAN_FILENAME}...")
            try:
                self._download(Config.SPAN_URL, model_path)
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise
            logger.info("Download complete.")
        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
            # Publish only a fully prepared model; on failure self.model stays None.
            self.model = model
            if not self.compiled:
                self._maybe_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def unload(self) -> None:
        """Drop the model and reset ``compiled`` so the next load re-evaluates compilation."""
        super().unload()
        self.compiled = False

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """
        Return a 2x upscaled copy of ``image`` as an RGB PIL image.

        Non-RGB inputs are converted to RGB first (alpha is discarded). Inference
        is tiled for safety on huge images. If tiling fails, the whole image is
        run in one pass instead.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        if image.mode != "RGB":
            image = image.convert("RGB")
        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        # SPAN is very efficient, but we still use tiling for safety on huge images
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)

        # Disable AMP for SPAN on CPU to avoid "UntypedStorage" weakref errors in inductor
        # SPAN architecture seems sensitive to autocast + compile on CPU
        dtype = torch.float32 if Config.DEVICE == 'cpu' else torch.float16
        # no_grad also covers the fallback, so .numpy() below never sees a
        # grad-requiring tensor and no autograd graph is built for the full image.
        with torch.no_grad():
            try:
                # Only use autocast if not CPU
                context = (
                    torch.autocast(device_type=Config.DEVICE, dtype=dtype)
                    if Config.DEVICE != 'cpu'
                    else torch.no_grad()
                )
                with context:
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=self.SCALE,
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back: {e}")
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)
        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return Image.fromarray(output_np)
class HatsStrategy(UpscalerStrategy):
    """HAT-S (4xNomos8kSCHAT-S) upscaler loaded through spandrel's ModelLoader."""

    # HAT-S is x4; passed to the tiler.
    SCALE = 4

    def __init__(self):
        super().__init__()
        self.name = "HAT-S x4"
        # True once the torch.compile decision was made for the currently loaded
        # model; reset by unload() because every load creates a new model object.
        self.compiled = False

    @staticmethod
    def _download(url: str, dest: str) -> None:
        """
        Stream ``url`` into ``dest`` without ever leaving a partial file at ``dest``.

        Bytes go to ``dest + ".part"`` and are renamed onto ``dest`` only after the
        transfer finished. ``load`` treats an existing ``dest`` as a complete
        download, so an interrupted direct write would break loading permanently.

        Raises:
            requests.RequestException: on HTTP errors or timeouts.
            OSError: if the file cannot be written or renamed.
        """
        part_path = dest + ".part"
        try:
            # (connect, read) timeouts so a stalled server cannot hang the request.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(part_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(part_path, dest)
        except BaseException:
            try:
                os.remove(part_path)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise

    def _maybe_compile(self) -> None:
        """
        Apply ``torch.compile`` only on CUDA; every CPU configuration skips it.

        Errors raised by torch.compile are logged and the uncompiled model is kept.
        """
        try:
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                pass  # Windows CPU: Inductor needs MSVC, so do not compile.
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                pass  # Low-core CPU: compilation would take too long.
            else:
                # HAT architecture also triggers "UntypedStorage" weakref errors with inductor on CPU
                logger.info("[INFO] Skipping torch.compile for HAT-S (incompatible architecture).")
        except Exception as e:
            # Compilation is optional: log the reason and continue uncompiled.
            logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
        self.compiled = True  # Mark as tried

    def load(self) -> None:
        """
        Make the model ready for inference; does nothing if already loaded.

        Downloads the checkpoint into ``Config.MODEL_DIR`` when missing, loads it
        with spandrel, sets eval mode and moves it to ``Config.DEVICE``.

        Raises:
            Exception: download and load errors are logged and re-raised.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.HATS_FILENAME)
        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.HATS_FILENAME}...")
            try:
                self._download(Config.HATS_URL, model_path)
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise
            logger.info("Download complete.")
        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
            # Publish only a fully prepared model; on failure self.model stays None.
            self.model = model
            if not self.compiled:
                self._maybe_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def unload(self) -> None:
        """Drop the model and reset ``compiled`` so the next load re-evaluates compilation."""
        super().unload()
        self.compiled = False

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """
        Return a 4x upscaled copy of ``image`` as an RGB PIL image.

        Non-RGB inputs are converted to RGB first (alpha is discarded). Inference
        is tiled. If tiling fails, the whole image is run in one pass instead.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        if image.mode != "RGB":
            image = image.convert("RGB")
        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)

        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.float32
        # no_grad also covers the fallback, so .numpy() below never sees a
        # grad-requiring tensor and no autograd graph is built for the full image.
        with torch.no_grad():
            try:
                context = (
                    torch.autocast(device_type=Config.DEVICE, dtype=dtype)
                    if Config.DEVICE != 'cpu'
                    else torch.no_grad()
                )
                with context:
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=self.SCALE,
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back: {e}")
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)
        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return Image.fromarray(output_np)
# --- Model Manager (Singleton-ish) ---
class UpscalerManager:
    """
    Owns every available upscaler and hands them out by display name.

    Switching between names unloads the previously selected model, so at most
    one model is kept in memory at a time.
    """

    def __init__(self):
        # Display name -> strategy; this order is also the UI dropdown order.
        self.strategies: Dict[str, UpscalerStrategy] = {
            "SPAN (NomosUni) x2": SpanStrategy(),
            "RealESRGAN x2": RealESRGANStrategy(),
            "HAT-S x4": HatsStrategy(),
        }
        # Name of the strategy handed out most recently, or None before the first request.
        self.current_model_name: Optional[str] = None

    def get_strategy(self, name: str) -> UpscalerStrategy:
        """
        Return the strategy registered as ``name``.

        Memory optimization for the free tier (16GB RAM limit): selecting a
        different name unloads the previously selected strategy first. The
        returned strategy loads its model lazily in ``upscale``.

        Raises:
            ValueError: if ``name`` is not registered.
        """
        selected = self.strategies.get(name)
        if selected is None:
            raise ValueError(f"Model {name} not found.")

        previous = self.current_model_name
        if previous == name:
            return selected
        if previous is not None:
            logger.info(f"Switching models: Unloading {previous}...")
            self.strategies[previous].unload()
        self.current_model_name = name
        return selected

    def unload_all(self):
        """Unload every strategy's model, then run a garbage-collection pass."""
        for key in self.strategies:
            self.strategies[key].unload()
        gc.collect()
        logger.info("All models unloaded.")


# Single shared manager used by the Gradio handlers.
manager = UpscalerManager()
# --- Gradio Interface Logic ---
def process_image(input_img: Image.Image, model_name: str, output_format: str) -> Tuple[Optional[str], str, str]:
    """
    Gradio handler: upscale ``input_img`` with ``model_name`` and save it as ``output_format``.

    Args:
        input_img: Image from the UI, or None if the user supplied none.
        model_name: Key of a strategy registered in ``manager``.
        output_format: Pillow format name chosen in the UI ("PNG", "JPEG", "WEBP").
            "JPG" is also accepted and saved as JPEG.

    Returns:
        ``(output_path, logs, system_usage)``. ``output_path`` is None when no
        image was given or processing failed; on failure the error and
        traceback are part of ``logs``.
    """
    if input_img is None:
        return None, get_logs(), get_system_usage()

    try:
        strategy = manager.get_strategy(model_name)
        output_img = strategy.upscale(input_img)

        # Pillow registers JPEG under "JPEG" only; accept the common "JPG" alias.
        pil_format = output_format.upper()
        if pil_format == "JPG":
            pil_format = "JPEG"

        # Save to file with correct extension
        output_path = f"output.{output_format.lower()}"

        # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...), so flatten to RGB.
        if pil_format == "JPEG" and output_img.mode not in ("RGB", "L"):
            output_img = output_img.convert("RGB")

        output_img.save(output_path, format=pil_format)

        # Explicit GC after heavy operations
        gc.collect()
        return output_path, get_logs(), get_system_usage()
    except Exception as e:
        error_msg = f"Critical Error: {str(e)}\n{traceback.format_exc()}"
        # Logging it puts the message (with traceback) into get_logs();
        # appending error_msg again would show it twice.
        logger.error(error_msg)
        return None, get_logs(), get_system_usage()
def unload_models():
    """
    Gradio handler for the "Unload All Models" button.

    Unloads every model, then returns ``(logs, system_usage)`` for the log
    panel and the status label.
    """
    manager.unload_all()
    logs_text = get_logs()
    usage_text = get_system_usage()
    return logs_text, usage_text
# --- UI Construction ---
# Markdown rendered at the top of the UI: model overview, licenses and credits.
desc = """
# Universal Upscaler Pro (CPU Optimized)

This application provides state-of-the-art (SOTA) image upscaling running entirely on CPU, optimized for free-tier cloud environments.

### Available Models

| Model | Scale | Best For | License |
| :--- | :--- | :--- | :--- |
| **SPAN (NomosUni)** | x2 | **Speed & General Use**. Extremely fast, parameter-free attention network. | Apache 2.0 |
| **RealESRGAN** | x2 | **Robustness**. Excellent at removing JPEG artifacts and noise. | BSD 3-Clause |
| **HAT-S** | x4 | **Texture Detail**. Hybrid Attention Transformer for high-fidelity restoration. | MIT |

### Attributions & Credits

* **Real-ESRGAN**: [Wang et al., 2021](https://github.com/xinntao/Real-ESRGAN). *Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data*.
* **SPAN**: [Wan et al., 2023](https://github.com/hongyuanyu/SPAN). *Swift Parameter-free Attention Network for Efficient Super-Resolution*.
* **HAT**: [Chen et al., 2023](https://github.com/XPixelGroup/HAT). *Activating More Pixels in Image Super-Resolution Transformer*.
* **NomosUni**: Custom SPAN training by [Phhofm](https://github.com/Phhofm).
"""
with gr.Blocks(title="Universal Upscaler Pro") as iface:
    gr.Markdown(desc)

    with gr.Row():
        # Left column: input image, model/format selection and controls.
        with gr.Column(min_width=300, scale=1):
            input_image = gr.Image(label="Input Image", type="pil", height=400)
            with gr.Row():
                model_selector = gr.Dropdown(
                    label="Model Architecture",
                    choices=list(manager.strategies),
                    value="SPAN (NomosUni) x2",
                    scale=2,
                )
                output_format = gr.Dropdown(
                    label="Output Format",
                    choices=["PNG", "JPEG", "WEBP"],
                    value="PNG",
                    scale=1,
                )
            submit_btn = gr.Button("Upscale Image", size="lg", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                unload_btn = gr.Button("Unload All Models (Free RAM)", variant="secondary")
                system_info = gr.Label(label="System Status", value=get_system_usage())

        # Right column: upscaled result and captured execution logs.
        with gr.Column(min_width=300, scale=1):
            output_image = gr.Image(label="Upscaled Result", type="filepath", height=400)
            logs_output = gr.TextArea(label="Execution Logs", lines=8, interactive=False)

    # Event Wiring
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, output_format],
        outputs=[output_image, logs_output, system_info],
    )
    unload_btn.click(
        fn=unload_models,
        inputs=[],
        outputs=[logs_output, system_info],
    )

iface.launch()