import contextlib
import gc
import io
import logging
import os
import time
import traceback
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple

import gradio as gr
import numpy as np
import psutil
import requests
import torch
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import (
    StableDiffusionUpscalePipeline,
)
from PIL import Image
from spandrel import ModelLoader
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution


# --- Configuration ---
class Config:
    """Configuration settings for the application."""

    MODEL_DIR = "weights"

    REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
    REALESRGAN_FILENAME = "RealESRGAN_x2plus.pth"
    SWIN2SR_ID = "caidas/swin2SR-classical-sr-x2-64"
    SD_ID = "stabilityai/stable-diffusion-x4-upscaler"

    # SOTA Models (2025)
    SPAN_URL = "https://huggingface.co/Phips/2xNomosUni_span_multijpg/resolve/main/2xNomosUni_span_multijpg.safetensors"
    SPAN_FILENAME = "2xNomosUni_span_multijpg.safetensors"
    HATS_URL = "https://huggingface.co/Phips/4xNomos8kSCHAT-S/resolve/main/4xNomos8kSCHAT-S.safetensors"
    HATS_FILENAME = "4xNomos8kSCHAT-S.safetensors"

    MAX_IMAGE_SIZE_SD = 512  # Max dimension for SD input to prevent OOM
    DEVICE = "cpu"  # Force CPU for this demo, can be "cuda" if available
    DOWNLOAD_TIMEOUT = 60  # Seconds to wait on connect/read when fetching weights

    @staticmethod
    def ensure_model_dir():
        """Create MODEL_DIR if it does not exist yet (race-free)."""
        os.makedirs(Config.MODEL_DIR, exist_ok=True)


# --- Logging Setup ---
class LogCapture(io.StringIO):
    """Custom StringIO to capture logs."""

    pass


log_capture_string = LogCapture()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)

logger = logging.getLogger("UpscalerApp")
logger.setLevel(logging.INFO)
logger.addHandler(ch)


def get_logs() -> str:
    """Retrieve all log text captured so far."""
    return log_capture_string.getvalue()


# --- System Monitoring ---
def get_system_usage() -> str:
    """Return a one-line summary of current CPU and RAM usage."""
    cpu_percent = psutil.cpu_percent()
    mem = psutil.virtual_memory()  # single snapshot so percent/used agree
    ram_used_gb = mem.used / (1024 ** 3)
    return f"CPU: {cpu_percent}% | RAM: {mem.percent}% ({ram_used_gb:.1f} GB used)"


# --- Abstract Base Class for Models ---
class UpscalerStrategy(ABC):
    """Abstract base class for upscaling strategies."""

    def __init__(self):
        self.model = None
        self.name = "Unknown"

    @abstractmethod
    def load(self) -> None:
        """Load the model into memory."""
        pass

    @abstractmethod
    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale the given image."""
        pass

    def unload(self) -> None:
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            gc.collect()
            # Release cached VRAM too; this is a no-op when CUDA was never initialised.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info(f"Unloaded {self.name}")


# --- Helper Functions for Optimization ---
def download_file(url: str, dest_path: str, timeout: float = Config.DOWNLOAD_TIMEOUT) -> None:
    """
    Stream ``url`` to ``dest_path`` atomically.

    The payload is written to ``dest_path + ".part"`` and only renamed into
    place once fully received. An interrupted download therefore never leaves
    a truncated file that a later ``os.path.exists`` check would mistake for
    valid weights.

    Raises:
        requests.RequestException: on network/HTTP errors (incl. timeout).
        OSError: on filesystem errors.
    """
    tmp_path = dest_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_path, dest_path)
    finally:
        # Only still present if something failed before os.replace.
        if os.path.exists(tmp_path):
            with contextlib.suppress(OSError):
                os.remove(tmp_path)


def manual_tile_upscale(model, img_tensor, tile_size=256, tile_pad=10, scale=2):
    """
    Low-level tiling implementation for custom models.

    Prevents OOM by processing the image in ``tile_size`` chunks. Each tile is
    read with up to ``tile_pad`` extra pixels of context on every side. That
    halo is cropped away after upscaling so seams are not visible.

    Args:
        model: Callable mapping a (B, C, h, w) tensor to (B, C, h*scale, w*scale).
        img_tensor: Input batch shaped (B, C, H, W).
        tile_size: Tile edge length in input pixels; must be positive.
        tile_pad: Halo width in input pixels; must be non-negative.
        scale: Integer upscale factor of ``model``.

    Returns:
        Tensor shaped (B, C, H*scale, W*scale) with img_tensor's device/dtype.

    Raises:
        ValueError: if tile_size/scale are not positive or tile_pad is negative.
    """
    if tile_size <= 0 or scale <= 0 or tile_pad < 0:
        raise ValueError(
            f"Invalid tiling parameters: tile_size={tile_size}, tile_pad={tile_pad}, scale={scale}"
        )

    B, C, H, W = img_tensor.shape
    output = torch.zeros(B, C, H * scale, W * scale, device=img_tensor.device, dtype=img_tensor.dtype)

    with torch.no_grad():
        for top in range(0, H, tile_size):
            bottom = min(top + tile_size, H)
            for left in range(0, W, tile_size):
                right = min(left + tile_size, W)

                # Expand the tile by the halo, clamped to the image bounds.
                pad_top = max(0, top - tile_pad)
                pad_left = max(0, left - tile_pad)
                pad_bottom = min(H, bottom + tile_pad)
                pad_right = min(W, right + tile_pad)

                tile_out = model(img_tensor[:, :, pad_top:pad_bottom, pad_left:pad_right])

                # Crop the upscaled halo back off, keeping only the core tile.
                crop_top = (top - pad_top) * scale
                crop_left = (left - pad_left) * scale
                crop_bottom = crop_top + (bottom - top) * scale
                crop_right = crop_left + (right - left) * scale

                output[:, :, top * scale:bottom * scale, left * scale:right * scale] = \
                    tile_out[:, :, crop_top:crop_bottom, crop_left:crop_right]

    return output


def select_tile_config(height, width):
    """
    Dynamically select tile size based on image resolution.

    Larger images get smaller tiles (and a wider halo) to bound peak memory.
    """
    megapixels = (height * width) / (1024 ** 2)
    if megapixels < 2:  # < 1080p
        return {'tile': 512, 'tile_pad': 10}
    elif megapixels < 6:  # < 4K
        return {'tile': 384, 'tile_pad': 15}
    elif megapixels < 16:  # < 8K
        return {'tile': 256, 'tile_pad': 20}
    else:  # 8K+
        return {'tile': 128, 'tile_pad': 25}


def _device_type() -> str:
    """Return the bare device type ("cpu"/"cuda"/...) for Config.DEVICE, e.g. "cuda:0" -> "cuda"."""
    return torch.device(Config.DEVICE).type


def _maybe_compile(model):
    """
    Wrap ``model`` with torch.compile when it is likely to help, else return it unchanged.

    Note: torch.compile is lazy, so a successful return here does not mean the
    first forward pass will compile. Callers must keep the eager model around
    as a fallback.
    """
    device_type = _device_type()
    try:
        # 'reduce-overhead' uses CUDA graphs, so only use it on CUDA
        if device_type == 'cuda':
            compiled = torch.compile(model, mode='reduce-overhead')
            logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            return compiled
        if os.name == 'nt' and device_type == 'cpu':
            # Windows requires MSVC for Inductor (default cpu backend); skip to avoid
            # "Compiler: cl is not found" unless the user has it.
            logger.info("ℹ Skipping torch.compile on Windows CPU to avoid MSVC requirement.")
            return model
        if (psutil.cpu_count(logical=False) or 0) < 4 and device_type == 'cpu':
            # Skip compilation on weak CPUs (e.g. HF Spaces Free Tier) to avoid long startup times
            logger.info("ℹ Skipping torch.compile on low-core CPU to prevent timeout.")
            return model
        compiled = torch.compile(model)
        logger.info("✓ torch.compile enabled (default mode)")
        return compiled
    except Exception as e:
        logger.warning(f"⚠ torch.compile not available or failed: {e}")
        return model


# --- Concrete Implementations ---
class SpandrelStrategy(UpscalerStrategy):
    """
    Shared implementation for single-file SR checkpoints loadable by spandrel.

    Handles downloading the weights, loading them on Config.DEVICE, optional
    torch.compile, and tiled mixed-precision inference with an FP32 eager
    fallback.
    """

    def __init__(self, name: str, url: str, filename: str, scale: int):
        super().__init__()
        self.name = name
        self.url = url
        self.filename = filename
        self.scale = scale  # replaced by the checkpoint's own scale on load
        self.compiled = False  # True once compilation has been attempted (tried only once)
        self._eager_model = None  # uncompiled model, used for fallback

    def load(self) -> None:
        """Download (if needed) and load the checkpoint; no-op when already loaded."""
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, self.filename)

        if not os.path.exists(model_path):
            logger.info(f"Downloading {self.filename}...")
            try:
                download_file(self.url, model_path)
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise
            logger.info("Download complete.")

        try:
            descriptor = ModelLoader().load_from_file(model_path)
            descriptor.eval()
            descriptor.to(Config.DEVICE)
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

        # Trust the scale the checkpoint declares over the hard-coded default.
        declared_scale = getattr(descriptor, "scale", None)
        if declared_scale:
            if int(declared_scale) != self.scale:
                logger.warning(f"{self.name}: checkpoint scale x{declared_scale} overrides x{self.scale}")
            self.scale = int(declared_scale)

        self._eager_model = descriptor
        self.model = descriptor
        if not self.compiled:
            self.model = _maybe_compile(descriptor)
            self.compiled = True  # Mark as tried
        logger.info(f"{self.name} loaded successfully.")

    def unload(self) -> None:
        """Drop both the (possibly compiled) model and the eager reference."""
        self._eager_model = None
        super().unload()

    def _run(self, model, img_tensor, tile_config):
        """Run ``model`` over ``img_tensor``, tiled when tile_config['tile'] > 0."""
        if tile_config['tile'] > 0:
            return manual_tile_upscale(
                model,
                img_tensor,
                tile_size=tile_config['tile'],
                tile_pad=tile_config['tile_pad'],
                scale=self.scale,
            )
        with torch.no_grad():
            return model(img_tensor)  # type: ignore

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` by self.scale; returns an RGB PIL image."""
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        # Normalise mode: RGBA/L/P inputs would otherwise give a wrong channel count or rank.
        img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        # Optimization: Dynamic Tiling
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")

        # Optimization: Mixed Precision (AMP). float16 on CUDA, bfloat16 on CPU.
        device_type = _device_type()
        amp_dtype = torch.float16 if device_type == 'cuda' else torch.bfloat16
        try:
            with torch.autocast(device_type=device_type, dtype=amp_dtype):
                output_tensor = self._run(self.model, img_tensor, tile_config)
        except Exception as e:
            logger.warning(f"AMP/Tiling failed, falling back to standard FP32: {e}")
            if self.model is not self._eager_model:
                # torch.compile errors surface lazily on first call; stop retrying it.
                logger.warning(f"Disabling torch.compile for {self.name}.")
                self.model = self._eager_model
            output_tensor = self._run(self._eager_model, img_tensor, tile_config)

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).float().clamp(0, 1).cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        output_megapixels = (output_np.shape[0] * output_np.shape[1]) / (1024 ** 2)
        logger.info(f"Speed: {output_megapixels / max(elapsed, 1e-9):.2f} MP/s")

        return Image.fromarray(output_np)


class RealESRGANStrategy(SpandrelStrategy):
    """RealESRGAN x2plus — robust against JPEG artifacts and noise."""

    def __init__(self):
        super().__init__("RealESRGAN x2", Config.REALESRGAN_URL, Config.REALESRGAN_FILENAME, scale=2)


class SpanStrategy(SpandrelStrategy):
    """SPAN (NomosUni) x2 — fastest CPU model."""

    def __init__(self):
        super().__init__("SPAN (NomosUni) x2", Config.SPAN_URL, Config.SPAN_FILENAME, scale=2)


class HatsStrategy(SpandrelStrategy):
    """HAT-S x4 — highest texture detail, slower."""

    def __init__(self):
        super().__init__("HAT-S x4", Config.HATS_URL, Config.HATS_FILENAME, scale=4)


class Swin2SRStrategy(UpscalerStrategy):
    """Swin2SR x2 via Hugging Face transformers."""

    def __init__(self):
        super().__init__()
        self.name = "Swin2SR x2"
        self.processor = None

    def load(self) -> None:
        """Load the image processor and model; no-op when already loaded."""
        if self.model is None:
            logger.info(f"Loading {self.name}...")
            try:
                self.processor = AutoImageProcessor.from_pretrained(Config.SWIN2SR_ID)
                model = Swin2SRForImageSuperResolution.from_pretrained(Config.SWIN2SR_ID)
                self.model = model.to(Config.DEVICE)  # type: ignore
                logger.info(f"{self.name} loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load Swin2SR: {e}")
                raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` with Swin2SR; output is exactly input size x model scale."""
        if self.model is None or self.processor is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        if self.processor is None:
            raise ValueError("Processor not loaded")

        rgb = image.convert("RGB")
        inputs = self.processor(images=rgb, return_tensors="pt").to(Config.DEVICE)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # The processor pads bottom/right to a window multiple; crop that border back off.
        scale = int(getattr(self.model.config, "upscale", 2))
        out_h, out_w = rgb.height * scale, rgb.width * scale

        output = outputs.reconstruction.data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        output = np.moveaxis(output, source=0, destination=-1)[:out_h, :out_w]
        output = (output * 255.0).round().astype(np.uint8)

        logger.info(f"Inference finished in {time.perf_counter() - start_time:.2f}s")
        return Image.fromarray(output)


class StableDiffusionStrategy(UpscalerStrategy):
    """Stable Diffusion x4 generative upscaler (prompt-guided)."""

    def __init__(self):
        super().__init__()
        self.name = "Stable Diffusion x4"

    def load(self) -> None:
        """Load the diffusers pipeline on Config.DEVICE; no-op when already loaded."""
        if self.model is None:
            logger.info(f"Loading {self.name} (this may take time)...")
            try:
                pipe = StableDiffusionUpscalePipeline.from_pretrained(
                    Config.SD_ID,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True
                )
                # Optimizations for CPU
                pipe.enable_attention_slicing("max")
                pipe.enable_vae_tiling()
                self.model = pipe.to(Config.DEVICE)
                logger.info(f"{self.name} loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load Stable Diffusion: {e}")
                raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` x4; accepts an optional ``prompt`` keyword."""
        if self.model is None:
            self.load()

        prompt = kwargs.get("prompt")
        if prompt is None:
            prompt = "high quality, detailed"

        image = image.convert("RGB")  # pipeline preprocessing expects 3 channels

        # Pre-check size
        if max(image.size) > Config.MAX_IMAGE_SIZE_SD:
            ratio = Config.MAX_IMAGE_SIZE_SD / max(image.size)
            # Clamp to >= 1 so extreme aspect ratios never produce a zero-sized image.
            new_size = (max(1, int(image.size[0] * ratio)), max(1, int(image.size[1] * ratio)))
            image = image.resize(new_size, Image.Resampling.LANCZOS)
            logger.warning(f"Resized input to {new_size} to prevent OOM on CPU.")

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()

        # Local generator: reproducible without reseeding torch's global RNG.
        generator = torch.Generator(device="cpu").manual_seed(42)
        output = self.model(
            prompt=prompt,
            image=image,
            num_inference_steps=20,
            guidance_scale=7.0,
            generator=generator
        ).images[0]  # type: ignore

        logger.info(f"Inference finished in {time.perf_counter() - start_time:.2f}s")
        return output


# --- Model Manager (Singleton-ish) ---
class UpscalerManager:
    """Manages model lifecycle and selection; keeps at most one model loaded."""

    def __init__(self):
        self.strategies: Dict[str, UpscalerStrategy] = {
            "SPAN (NomosUni) x2": SpanStrategy(),
            "RealESRGAN x2": RealESRGANStrategy(),
            "HAT-S x4": HatsStrategy(),
            "Swin2SR x2": Swin2SRStrategy(),
            "Stable Diffusion x4": StableDiffusionStrategy()
        }
        self.current_model_name: Optional[str] = None

    def get_strategy(self, name: str) -> UpscalerStrategy:
        """Return the strategy for ``name``, unloading the previously active one."""
        if name not in self.strategies:
            raise ValueError(f"Model {name} not found.")

        # Memory Optimization for Free Tier (16GB RAM limit):
        # Ensure only one model is loaded at a time.
        if self.current_model_name != name:
            if self.current_model_name is not None:
                logger.info(f"Switching models: Unloading {self.current_model_name}...")
                self.strategies[self.current_model_name].unload()
            self.current_model_name = name

        return self.strategies[name]

    def unload_all(self):
        """Unload all models to free memory."""
        for strategy in self.strategies.values():
            strategy.unload()
        self.current_model_name = None
        gc.collect()
        logger.info("All models unloaded.")


manager = UpscalerManager()


# --- Gradio Interface Logic ---
def process_image(input_img: Image.Image, model_name: str, prompt: str) -> Tuple[Optional[Image.Image], str, str]:
    """Gradio handler: upscale ``input_img`` with ``model_name``; returns (image, logs, status)."""
    if input_img is None:
        return None, get_logs(), get_system_usage()

    try:
        strategy = manager.get_strategy(model_name)
        output = strategy.upscale(input_img, prompt=prompt)

        # Explicit GC after heavy operations
        gc.collect()

        return output, get_logs(), get_system_usage()
    except Exception as e:
        # logger writes into the captured log, so get_logs() already shows the error.
        logger.error(f"Critical Error: {str(e)}\n{traceback.format_exc()}")
        return None, get_logs(), get_system_usage()


def unload_models():
    """Gradio handler: free every loaded model; returns (logs, status)."""
    manager.unload_all()
    return get_logs(), get_system_usage()


# --- UI Construction ---
desc = """
### 🚀 Enterprise-Grade Universal Upscaler (SOTA 2025)
Select a specialized model to upscale your image.

* **SPAN (NomosUni) x2**: ⚡ **SOTA Speed**. Fastest CPU model. Best for general use.
* **RealESRGAN x2**: 🛡️ **Robust**. Best for removing JPEG artifacts and noise.
* **HAT-S x4**: 💎 **SOTA Quality**. Best texture detail (slower).
* **Swin2SR x2**: 🎯 High fidelity, removes compression artifacts.
* **Stable Diffusion x4**: 🎨 Generative upscaling. Adds missing details (slow, high RAM).
"""

with gr.Blocks(title="Universal Upscaler Pro") as iface:
    gr.Markdown(desc)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")

            with gr.Group():
                model_selector = gr.Dropdown(
                    choices=list(manager.strategies.keys()),
                    value="SPAN (NomosUni) x2",
                    label="Select Model Architecture"
                )
                prompt_input = gr.Textbox(
                    label="Prompt (Stable Diffusion Only)",
                    value="highly detailed, 4k, sharp",
                    placeholder="Describe the image content..."
                )

            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Memory Management")
                unload_btn = gr.Button("Unload All Models (Free RAM)", variant="secondary")

            submit_btn = gr.Button("✨ Upscale Image", variant="primary", size="lg")
            system_info = gr.Label(value=get_system_usage(), label="System Status")

        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Upscaled Result")
            logs_output = gr.TextArea(label="Execution Logs", interactive=False, lines=10)

    # Event Wiring
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, prompt_input],
        outputs=[output_image, logs_output, system_info]
    )

    unload_btn.click(
        fn=unload_models,
        inputs=[],
        outputs=[logs_output, system_info]
    )

    # Auto-refresh system info every 2 seconds (optional, can be heavy on UI):
    # iface.load(get_system_usage, None, system_info, every=2)

if __name__ == "__main__":
    iface.launch()