"""Gradio demo that upscales images with spandrel-loaded super-resolution models.

Three models are offered (SPAN x2, RealESRGAN x2, HAT-S x4). Weights are
downloaded on first use into ``Config.MODEL_DIR``. ``UpscalerManager`` keeps at
most one model loaded at a time.
"""

import gradio as gr
import torch
import gc
from PIL import Image
import numpy as np
import logging
import io
import os
import requests
from spandrel import ModelLoader
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict
import psutil
import time
import traceback


# --- Configuration ---
class Config:
    """Configuration settings for the application."""

    MODEL_DIR = "."
    REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
    REALESRGAN_FILENAME = "RealESRGAN_x2plus.pth"
    # SOTA Models (2025)
    SPAN_URL = "https://huggingface.co/Phips/2xNomosUni_span_multijpg/resolve/main/2xNomosUni_span_multijpg.safetensors"
    SPAN_FILENAME = "2xNomosUni_span_multijpg.safetensors"
    HATS_URL = "https://huggingface.co/Phips/4xNomos8kSCHAT-S/resolve/main/4xNomos8kSCHAT-S.safetensors"
    HATS_FILENAME = "4xNomos8kSCHAT-S.safetensors"
    DEVICE = "cpu"  # Force CPU for this demo, can be "cuda" if available
    # Seconds requests waits to connect / between received bytes when downloading weights.
    DOWNLOAD_TIMEOUT = 60

    @staticmethod
    def ensure_model_dir():
        """Create ``MODEL_DIR`` if it does not exist yet."""
        # exist_ok avoids the race between an existence check and the creation.
        os.makedirs(Config.MODEL_DIR, exist_ok=True)


# --- Logging Setup ---
class LogCapture(io.StringIO):
    """Custom StringIO to capture logs."""
    pass


log_capture_string = LogCapture()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger = logging.getLogger("UpscalerApp")
logger.setLevel(logging.INFO)
logger.addHandler(ch)


def get_logs() -> str:
    """Return every log line captured since start-up.

    NOTE(review): the buffer is never truncated, so it grows for the lifetime
    of the process and is shared by all users of the app.
    """
    return log_capture_string.getvalue()


# --- System Monitoring ---
def get_system_usage() -> str:
    """Return a one-line summary of current CPU and RAM usage."""
    cpu_percent = psutil.cpu_percent()
    # Take a single memory snapshot so percent and GB figures agree.
    memory = psutil.virtual_memory()
    ram_percent = memory.percent
    ram_used_gb = memory.used / (1024 ** 3)
    return f"CPU: {cpu_percent}% | RAM: {ram_percent}% ({ram_used_gb:.1f} GB used)"


# --- Abstract Base Class for Models ---
class UpscalerStrategy(ABC):
    """Abstract base class for upscaling strategies.

    Subclasses set ``self.name``, populate ``self.model`` in ``load()`` and use
    it in ``upscale()``.
    """

    def __init__(self):
        self.model = None
        self.name = "Unknown"

    @abstractmethod
    def load(self) -> None:
        """Load the model into memory."""
        pass

    @abstractmethod
    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale the given image."""
        pass

    def unload(self) -> None:
        """Drop the model reference and run the garbage collector to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            gc.collect()
            logger.info(f"Unloaded {self.name}")


# --- Helper Functions ---
def _download_if_missing(url: str, filename: str) -> str:
    """Return the path of ``filename`` inside MODEL_DIR, downloading it from ``url`` first if absent.

    The body is streamed into a ``.part`` sibling and renamed into place only
    after the transfer finishes. An interrupted download therefore never
    leaves a truncated weights file that later runs would mistake for a
    complete one.

    Raises:
        Exception: any error from ``requests`` or file I/O is logged and re-raised.
    """
    Config.ensure_model_dir()
    model_path = os.path.join(Config.MODEL_DIR, filename)
    if os.path.exists(model_path):
        return model_path

    logger.info(f"Downloading {filename}...")
    tmp_path = model_path + ".part"
    try:
        # The context manager releases the streamed connection even on error.
        with requests.get(url, stream=True, timeout=Config.DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(tmp_path, model_path)
        logger.info("Download complete.")
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        # Remove the partial file so the next attempt starts from scratch.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass
        raise
    return model_path


def _image_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 ``(1, 3, H, W)`` tensor scaled to [0, 1] on ``Config.DEVICE``."""
    # Grayscale / palette / RGBA inputs would produce a 2-D or 4-channel array.
    # The permute below and (presumably) the 3-channel models cannot handle those.
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_np = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)


def _tensor_to_image(output_tensor: torch.Tensor) -> Image.Image:
    """Convert a ``(1, C, H, W)`` model output in [0, 1] back to an 8-bit PIL image."""
    output_np = output_tensor.detach().squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
    output_np = (output_np * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_np)


def manual_tile_upscale(model, img_tensor, tile_size=256, tile_pad=10, scale=2):
    """Upscale ``img_tensor`` tile by tile to bound peak memory.

    Each ``tile_size`` x ``tile_size`` tile is fed to ``model`` with up to
    ``tile_pad`` pixels of surrounding context (the "halo"). The halo is then
    cropped off the result, so neighbouring tiles stitch together without
    border artifacts.

    Args:
        model: callable mapping a ``(B, C, h, w)`` tensor to ``(B, C, h*scale, w*scale)``.
        img_tensor: input of shape ``(B, C, H, W)``.
        tile_size: edge length of each (unpadded) input tile.
        tile_pad: context pixels added on each side of a tile, clipped at image borders.
        scale: the model's upscaling factor.

    Returns:
        Tensor of shape ``(B, C, H*scale, W*scale)`` with the input's device and dtype.
    """
    B, C, H, W = img_tensor.shape
    n_rows = (H + tile_size - 1) // tile_size
    n_cols = (W + tile_size - 1) // tile_size
    output = torch.zeros(B, C, H * scale, W * scale, device=img_tensor.device, dtype=img_tensor.dtype)

    with torch.no_grad():
        for row in range(n_rows):
            for col in range(n_cols):
                # Tile bounds in input pixels (rows = height axis, cols = width axis).
                top = row * tile_size
                left = col * tile_size
                bottom = min(top + tile_size, H)
                right = min(left + tile_size, W)

                # Expand by the halo, clipped to the image.
                top_pad = max(0, top - tile_pad)
                left_pad = max(0, left - tile_pad)
                bottom_pad = min(H, bottom + tile_pad)
                right_pad = min(W, right + tile_pad)

                tile_out = model(img_tensor[:, :, top_pad:bottom_pad, left_pad:right_pad])

                # Locate the un-haloed region inside the upscaled tile.
                crop_top = (top - top_pad) * scale
                crop_left = (left - left_pad) * scale
                crop_bottom = crop_top + (bottom - top) * scale
                crop_right = crop_left + (right - left) * scale

                output[:, :, top * scale:bottom * scale, left * scale:right * scale] = \
                    tile_out[:, :, crop_top:crop_bottom, crop_left:crop_right]
    return output


def select_tile_config(height, width):
    """Pick tile size and halo from the image size: larger images get smaller tiles.

    Returns:
        dict with ``'tile'`` (tile edge in pixels) and ``'tile_pad'`` (halo in pixels).
    """
    megapixels = (height * width) / (1024 ** 2)
    if megapixels < 2:  # < 1080p
        return {'tile': 512, 'tile_pad': 10}
    elif megapixels < 6:  # < 4K
        return {'tile': 384, 'tile_pad': 15}
    elif megapixels < 16:  # < 8K
        return {'tile': 256, 'tile_pad': 20}
    else:  # 8K+
        return {'tile': 128, 'tile_pad': 25}


# --- Concrete Implementations ---
class RealESRGANStrategy(UpscalerStrategy):
    """2x upscaling with the RealESRGAN x2plus weights."""

    def __init__(self):
        super().__init__()
        self.name = "RealESRGAN x2"
        self.compiled = False

    def load(self) -> None:
        """Download the weights if needed, load them, and try ``torch.compile`` where it is safe."""
        if self.model is None:
            logger.info(f"Loading {self.name}...")
            model_path = _download_if_missing(Config.REALESRGAN_URL, Config.REALESRGAN_FILENAME)
            try:
                self.model = ModelLoader().load_from_file(model_path)
                self.model.eval()
                self.model.to(Config.DEVICE)

                # Optimization: torch.compile
                # NOTE(review): `compiled` is never reset by unload(), so a model reloaded
                # after a switch is not wrapped with torch.compile again. Confirm intent.
                if not self.compiled:
                    try:
                        # 'reduce-overhead' uses CUDA graphs, so only use it on CUDA
                        if Config.DEVICE == 'cuda':
                            self.model = torch.compile(self.model, mode='reduce-overhead')
                            logger.info("[INFO] torch.compile enabled (reduce-overhead mode)")
                        elif os.name == 'nt' and Config.DEVICE == 'cpu':
                            # Windows requires MSVC for Inductor (default cpu backend).
                            # Skip it to avoid "Compiler: cl is not found" unless the user has it.
                            logger.info("[INFO] Skipping torch.compile on Windows CPU to avoid MSVC requirement.")
                        elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                            # Skip compilation on weak CPUs (e.g. HF Spaces Free Tier) to avoid long startup times
                            logger.info("[INFO] Skipping torch.compile on low-core CPU to prevent timeout.")
                        else:
                            # On Linux/Mac CPU, the default mode is usually safe.
                            self.model = torch.compile(self.model)
                            logger.info("[SUCCESS] torch.compile enabled (default mode)")
                        self.compiled = True
                    except Exception as e:
                        logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
                        self.compiled = True  # Mark as tried

                logger.info(f"{self.name} loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load model architecture: {e}")
                raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled 2x. ``kwargs`` are accepted for interface compatibility and ignored."""
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        img_tensor = _image_to_tensor(image)

        # Optimization: Dynamic Tiling
        h, w = img_tensor.shape[-2:]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")

        # Mixed precision is only enabled off-CPU (float16 on CUDA). On CPU, autocast is
        # deliberately skipped to avoid "PythonFallbackKernel" errors seen with some ops.
        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
        try:
            context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
            with context:
                if tile_config['tile'] > 0:
                    output_tensor = manual_tile_upscale(
                        self.model,
                        img_tensor,
                        tile_size=tile_config['tile'],
                        tile_pad=tile_config['tile_pad'],
                        scale=2,
                    )
                else:
                    output_tensor = self.model(img_tensor)  # type: ignore
        except Exception as e:
            logger.warning(f"AMP/Tiling failed, falling back to standard FP32: {e}")
            # No autograd graph is needed for inference, and tensors that require grad
            # cannot be converted with .numpy().
            with torch.no_grad():
                output_tensor = self.model(img_tensor)  # type: ignore

        result = _tensor_to_image(output_tensor)

        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")

        # Benchmark info: output throughput in megapixels per second.
        output_megapixels = (result.width * result.height) / (1024 ** 2)
        throughput = output_megapixels / max(elapsed, 1e-9)
        logger.info(f"Speed: {throughput:.2f} MP/s")

        return result


class SpanStrategy(UpscalerStrategy):
    """2x upscaling with the SPAN (NomosUni) weights."""

    def __init__(self):
        super().__init__()
        self.name = "SPAN (NomosUni) x2"
        self.compiled = False

    def load(self) -> None:
        """Download the weights if needed and load them. ``torch.compile`` is used only on CUDA."""
        if self.model is None:
            logger.info(f"Loading {self.name}...")
            model_path = _download_if_missing(Config.SPAN_URL, Config.SPAN_FILENAME)
            try:
                self.model = ModelLoader().load_from_file(model_path)
                self.model.eval()
                self.model.to(Config.DEVICE)

                # Optimization: torch.compile
                if not self.compiled:
                    try:
                        if Config.DEVICE == 'cuda':
                            self.model = torch.compile(self.model, mode='reduce-overhead')
                            logger.info("[INFO] torch.compile enabled (reduce-overhead mode)")
                        elif os.name == 'nt' and Config.DEVICE == 'cpu':
                            logger.info("[INFO] Skipping torch.compile on Windows CPU.")
                        elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                            logger.info("[INFO] Skipping torch.compile on low-core CPU.")
                        else:
                            # SPAN architecture uses .data.clone() in forward pass which breaks torch.compile/inductor
                            logger.info("[INFO] Skipping torch.compile for SPAN (incompatible architecture).")
                        self.compiled = True
                    except Exception as e:
                        logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
                        self.compiled = True  # Mark as tried

                logger.info(f"{self.name} loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load model architecture: {e}")
                raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled 2x. ``kwargs`` are accepted for interface compatibility and ignored."""
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        img_tensor = _image_to_tensor(image)

        # SPAN is very efficient, but tiling is still used for safety on huge images.
        h, w = img_tensor.shape[-2:]
        tile_config = select_tile_config(h, w)

        # AMP is disabled on CPU to avoid "UntypedStorage" weakref errors in inductor:
        # SPAN appears sensitive to autocast + compile on CPU.
        dtype = torch.float32 if Config.DEVICE == 'cpu' else torch.float16
        try:
            context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
            with context:
                if tile_config['tile'] > 0:
                    output_tensor = manual_tile_upscale(
                        self.model,
                        img_tensor,
                        tile_size=tile_config['tile'],
                        tile_pad=tile_config['tile_pad'],
                        scale=2,
                    )
                else:
                    output_tensor = self.model(img_tensor)  # type: ignore
        except Exception as e:
            logger.warning(f"AMP/Tiling failed, falling back: {e}")
            with torch.no_grad():
                output_tensor = self.model(img_tensor)  # type: ignore

        result = _tensor_to_image(output_tensor)

        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return result


class HatsStrategy(UpscalerStrategy):
    """4x upscaling with the HAT-S (Nomos8kSC) weights."""

    def __init__(self):
        super().__init__()
        self.name = "HAT-S x4"
        self.compiled = False

    def load(self) -> None:
        """Download the weights if needed and load them. ``torch.compile`` is used only on CUDA."""
        if self.model is None:
            logger.info(f"Loading {self.name}...")
            model_path = _download_if_missing(Config.HATS_URL, Config.HATS_FILENAME)
            try:
                self.model = ModelLoader().load_from_file(model_path)
                self.model.eval()
                self.model.to(Config.DEVICE)

                if not self.compiled:
                    try:
                        if Config.DEVICE == 'cuda':
                            self.model = torch.compile(self.model, mode='reduce-overhead')
                        elif os.name == 'nt' and Config.DEVICE == 'cpu':
                            pass
                        elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                            pass
                        else:
                            # HAT architecture also triggers "UntypedStorage" weakref errors with inductor on CPU
                            logger.info("[INFO] Skipping torch.compile for HAT-S (incompatible architecture).")
                        self.compiled = True
                    except Exception as e:
                        logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
                        self.compiled = True  # Mark as tried

                logger.info(f"{self.name} loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load model architecture: {e}")
                raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled 4x. ``kwargs`` are accepted for interface compatibility and ignored."""
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        img_tensor = _image_to_tensor(image)

        h, w = img_tensor.shape[-2:]
        tile_config = select_tile_config(h, w)

        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.float32
        try:
            context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
            with context:
                if tile_config['tile'] > 0:
                    output_tensor = manual_tile_upscale(
                        self.model,
                        img_tensor,
                        tile_size=tile_config['tile'],
                        tile_pad=tile_config['tile_pad'],
                        scale=4,  # HAT-S is x4
                    )
                else:
                    output_tensor = self.model(img_tensor)  # type: ignore
        except Exception as e:
            logger.warning(f"AMP/Tiling failed, falling back: {e}")
            with torch.no_grad():
                output_tensor = self.model(img_tensor)  # type: ignore

        result = _tensor_to_image(output_tensor)

        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return result


# --- Model Manager (Singleton-ish) ---
class UpscalerManager:
    """Manages model lifecycle and selection, keeping at most one model loaded."""

    def __init__(self):
        self.strategies: Dict[str, UpscalerStrategy] = {
            "SPAN (NomosUni) x2": SpanStrategy(),
            "RealESRGAN x2": RealESRGANStrategy(),
            "HAT-S x4": HatsStrategy(),
        }
        self.current_model_name: Optional[str] = None

    def get_strategy(self, name: str) -> UpscalerStrategy:
        """Return the strategy called ``name``, unloading the previously selected one if it differs.

        Raises:
            ValueError: if ``name`` is not a known model.
        """
        if name not in self.strategies:
            raise ValueError(f"Model {name} not found.")

        # Memory Optimization for Free Tier (16GB RAM limit):
        # Ensure only one model is loaded at a time.
        if self.current_model_name != name:
            if self.current_model_name is not None:
                logger.info(f"Switching models: Unloading {self.current_model_name}...")
                self.strategies[self.current_model_name].unload()
            self.current_model_name = name

        return self.strategies[name]

    def unload_all(self):
        """Unload all models to free memory."""
        for strategy in self.strategies.values():
            strategy.unload()
        gc.collect()
        logger.info("All models unloaded.")


manager = UpscalerManager()


# --- Gradio Interface Logic ---
def process_image(input_img: Image.Image, model_name: str, output_format: str) -> Tuple[Optional[str], str, str]:
    """Upscale ``input_img`` with ``model_name`` and save it as ``output_format``.

    Returns:
        (path of the saved image or None, captured logs, system usage line).
    """
    if input_img is None:
        return None, get_logs(), get_system_usage()

    try:
        strategy = manager.get_strategy(model_name)
        output_img = strategy.upscale(input_img)

        # Save to a file with the extension matching the chosen format.
        # NOTE(review): a fixed filename in the working directory is shared by all requests.
        output_path = f"output.{output_format.lower()}"

        # Convert to RGB if saving as JPEG (doesn't support alpha)
        if output_format.lower() in ['jpeg', 'jpg'] and output_img.mode == 'RGBA':
            output_img = output_img.convert('RGB')

        output_img.save(output_path, format=output_format)

        # Explicit GC after heavy operations
        gc.collect()

        return output_path, get_logs(), get_system_usage()
    except Exception as e:
        error_msg = f"Critical Error: {str(e)}\n{traceback.format_exc()}"
        # logger.error writes into the capture buffer, so get_logs() already contains
        # error_msg; appending it again would show the traceback twice.
        logger.error(error_msg)
        return None, get_logs(), get_system_usage()


def unload_models():
    """Unload every model and return the refreshed logs and system usage line."""
    manager.unload_all()
    return get_logs(), get_system_usage()


# --- UI Construction ---
desc = """
# Universal Upscaler Pro (CPU Optimized)

This application provides state-of-the-art (SOTA) image upscaling running entirely on CPU, optimized for free-tier cloud environments.

### Available Models

| Model | Scale | Best For | License |
| :--- | :--- | :--- | :--- |
| **SPAN (NomosUni)** | x2 | **Speed & General Use**. Extremely fast, parameter-free attention network. | Apache 2.0 |
| **RealESRGAN** | x2 | **Robustness**. Excellent at removing JPEG artifacts and noise. | BSD 3-Clause |
| **HAT-S** | x4 | **Texture Detail**. Hybrid Attention Transformer for high-fidelity restoration. | MIT |

### Attributions & Credits

* **Real-ESRGAN**: [Wang et al., 2021](https://github.com/xinntao/Real-ESRGAN). *Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data*.
* **SPAN**: [Wan et al., 2023](https://github.com/hongyuanyu/SPAN). *Swift Parameter-free Attention Network for Efficient Super-Resolution*.
* **HAT**: [Chen et al., 2023](https://github.com/XPixelGroup/HAT). *Activating More Pixels in Image Super-Resolution Transformer*.
* **NomosUni**: Custom SPAN training by [Phhofm](https://github.com/Phhofm).
"""

with gr.Blocks(title="Universal Upscaler Pro") as iface:
    gr.Markdown(desc)

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            input_image = gr.Image(type="pil", label="Input Image", height=400)

            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(manager.strategies.keys()),
                    value="SPAN (NomosUni) x2",
                    label="Model Architecture",
                    scale=2,
                )
                output_format = gr.Dropdown(
                    choices=["PNG", "JPEG", "WEBP"],
                    value="PNG",
                    label="Output Format",
                    scale=1,
                )

            submit_btn = gr.Button("Upscale Image", variant="primary", size="lg")

            with gr.Accordion("Advanced Settings", open=False):
                unload_btn = gr.Button("Unload All Models (Free RAM)", variant="secondary")
                system_info = gr.Label(value=get_system_usage(), label="System Status")

        with gr.Column(scale=1, min_width=300):
            output_image = gr.Image(type="filepath", label="Upscaled Result", height=400)
            logs_output = gr.TextArea(label="Execution Logs", interactive=False, lines=8)

    # Event Wiring
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, output_format],
        outputs=[output_image, logs_output, system_info],
    )

    unload_btn.click(
        fn=unload_models,
        inputs=[],
        outputs=[logs_output, system_info],
    )

iface.launch()