import sys
import gradio as gr
import os
import numpy as np
import cv2
import time
import shutil
from pathlib import Path
from einops import rearrange
from typing import Union

# Force unbuffered output for HF Spaces logs
os.environ['PYTHONUNBUFFERED'] = '1'

# Configure logging FIRST before any other imports
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

logger.info("=" * 50)
logger.info("Starting application initialization...")
logger.info("=" * 50)
sys.stdout.flush()

try:
    import spaces
    logger.info("✅ HF Spaces module imported successfully")
except ImportError:
    logger.warning("⚠️ HF Spaces module not available, using mock")

    class spaces:
        @staticmethod
        def GPU(func=None, duration=None):
            def decorator(f):
                return f
            return decorator if func is None else func

sys.stdout.flush()

logger.info("Importing torch...")
sys.stdout.flush()
import torch
logger.info(f"✅ Torch imported. Version: {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
sys.stdout.flush()

import torch.nn.functional as F
import torchvision.transforms as T
from concurrent.futures import ThreadPoolExecutor
import atexit
import uuid

logger.info("Importing decord...")
sys.stdout.flush()
import decord
logger.info("✅ Decord imported successfully")
sys.stdout.flush()

from PIL import Image

logger.info("Importing SpaTrack models...")
sys.stdout.flush()
try:
    from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track
    from models.SpaTrackV2.models.vggt4track.utils.load_fn import preprocess_image
    from models.SpaTrackV2.models.predictor import Predictor
    from models.SpaTrackV2.models.utils import get_points_on_a_grid
    logger.info("✅ SpaTrack models imported successfully")
except Exception as e:
    logger.error(f"❌ Failed to import SpaTrack models: {e}")
    raise
sys.stdout.flush()

# TTM imports (optional - will be loaded on demand)
logger.info("Checking TTM (diffusers) availability...")
sys.stdout.flush()

TTM_COG_AVAILABLE = False
TTM_WAN_AVAILABLE = False

try:
    from diffusers import CogVideoXImageToVideoPipeline
    from diffusers.utils import export_to_video, load_image
    from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
    from diffusers.utils.torch_utils import randn_tensor
    from diffusers.video_processor import VideoProcessor
    TTM_COG_AVAILABLE = True
    logger.info("✅ CogVideoX TTM available")
except ImportError as e:
    logger.info(f"ℹ️ CogVideoX TTM not available: {e}")
sys.stdout.flush()

try:
    from diffusers import AutoencoderKLWan, WanTransformer3DModel
    from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
    from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline, retrieve_latents
    from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
    if not TTM_COG_AVAILABLE:
        from diffusers.utils import export_to_video, load_image
        from diffusers.utils.torch_utils import randn_tensor
        from diffusers.video_processor import VideoProcessor
    TTM_WAN_AVAILABLE = True
    logger.info("✅ Wan TTM available")
except ImportError as e:
    logger.info(f"ℹ️ Wan TTM not available: {e}")
sys.stdout.flush()

TTM_AVAILABLE = TTM_COG_AVAILABLE or TTM_WAN_AVAILABLE
if not TTM_AVAILABLE:
    logger.warning("⚠️ Diffusers not available. TTM features will be disabled.")
else:
    logger.info(f"TTM Status - CogVideoX: {TTM_COG_AVAILABLE}, Wan: {TTM_WAN_AVAILABLE}")
sys.stdout.flush()

# Constants
MAX_FRAMES = 80
OUTPUT_FPS = 24
RENDER_WIDTH = 512
RENDER_HEIGHT = 384

# Camera movement types
CAMERA_MOVEMENTS = [
    "static",
    "move_forward",
    "move_backward",
    "move_left",
    "move_right",
    "move_up",
    "move_down"
]

# TTM Constants
TTM_COG_MODEL_ID = "THUDM/CogVideoX-5b-I2V"
TTM_WAN_MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
TTM_DTYPE = torch.bfloat16
TTM_DEFAULT_NUM_FRAMES = 49
TTM_DEFAULT_NUM_INFERENCE_STEPS = 50

# TTM Model choices
TTM_MODELS = []
if TTM_COG_AVAILABLE:
    TTM_MODELS.append("CogVideoX-5B")
if TTM_WAN_AVAILABLE:
    TTM_MODELS.append("Wan2.2-14B (Recommended)")

# Global model instances (lazy loaded for HF Spaces GPU compatibility)
vggt4track_model = None
tracker_model = None
ttm_cog_pipeline = None
ttm_wan_pipeline = None
MODELS_LOADED = False


def load_video_to_tensor(video_path: str) -> torch.Tensor:
    """Returns a video tensor from a video file. shape [1, C, T, H, W], [0, 1] range."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
    cap.release()
    frames = np.array(frames)
    video_tensor = torch.tensor(frames)
    video_tensor = video_tensor.permute(0, 3, 1, 2).float() / 255.0
    video_tensor = video_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4)
    return video_tensor


def get_ttm_cog_pipeline():
    """Lazy load CogVideoX TTM pipeline to save memory."""
    global ttm_cog_pipeline
    if ttm_cog_pipeline is None and TTM_COG_AVAILABLE:
        logger.info("Loading TTM CogVideoX pipeline...")
        ttm_cog_pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
            TTM_COG_MODEL_ID,
            torch_dtype=TTM_DTYPE,
            low_cpu_mem_usage=True,
        )
        ttm_cog_pipeline.vae.enable_tiling()
        ttm_cog_pipeline.vae.enable_slicing()
        logger.info("TTM CogVideoX pipeline loaded successfully!")
    return ttm_cog_pipeline


def get_ttm_wan_pipeline():
    """Lazy load Wan TTM pipeline to save memory."""
    global ttm_wan_pipeline
    if ttm_wan_pipeline is None and TTM_WAN_AVAILABLE:
        logger.info("Loading TTM Wan 2.2 pipeline...")
        ttm_wan_pipeline = WanImageToVideoPipeline.from_pretrained(
            TTM_WAN_MODEL_ID,
            torch_dtype=TTM_DTYPE,
        )
        ttm_wan_pipeline.vae.enable_tiling()
        ttm_wan_pipeline.vae.enable_slicing()
        logger.info("TTM Wan 2.2 pipeline loaded successfully!")
    return ttm_wan_pipeline


logger.info("Setting up thread pool and utility functions...")
sys.stdout.flush()

# Thread pool for delayed deletion
thread_pool_executor = ThreadPoolExecutor(max_workers=2)


def load_models():
    """Load models lazily when GPU is available (inside @spaces.GPU decorated function)."""
    global vggt4track_model, tracker_model, MODELS_LOADED

    if MODELS_LOADED:
        logger.info("Models already loaded, skipping...")
        return

    logger.info("🚀 Starting model loading...")
    sys.stdout.flush()

    try:
        logger.info("Loading VGGT4Track model from 'Yuxihenry/SpatialTrackerV2_Front'...")
        sys.stdout.flush()
        vggt4track_model = VGGT4Track.from_pretrained("Yuxihenry/SpatialTrackerV2_Front")
        vggt4track_model.eval()
        logger.info("✅ VGGT4Track model loaded, moving to CUDA...")
        sys.stdout.flush()
        vggt4track_model = vggt4track_model.to("cuda")
        logger.info("✅ VGGT4Track model on CUDA")
        sys.stdout.flush()

        logger.info("Loading Predictor model from 'Yuxihenry/SpatialTrackerV2-Offline'...")
        sys.stdout.flush()
        tracker_model = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
        tracker_model.eval()
        logger.info("✅ Predictor model loaded")
        sys.stdout.flush()
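        # Both networks are now initialized in this process:
        #   - vggt4track_model: feed-forward depth + camera-pose estimation (VGGT front-end)
        #   - tracker_model: SpaTrackV2 3D point tracker that consumes the depth,
        #     intrinsics and extrinsics produced by the front-end
        # Note: tracker_model is moved to CUDA later, inside run_spatial_tracker().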
        MODELS_LOADED = True
        logger.info("✅ All models loaded successfully!")
        sys.stdout.flush()
    except Exception as e:
        logger.error(f"❌ Failed to load models: {e}")
        import traceback
        traceback.print_exc()
        sys.stdout.flush()
        raise


def delete_later(path: Union[str, os.PathLike], delay: int = 600):
    """Delete file or directory after specified delay"""
    def _delete():
        try:
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            logger.warning(f"Failed to delete {path}: {e}")

    def _wait_and_delete():
        time.sleep(delay)
        _delete()

    thread_pool_executor.submit(_wait_and_delete)
    atexit.register(_delete)


def create_user_temp_dir():
    """Create a unique temporary directory for each user session"""
    session_id = str(uuid.uuid4())[:8]
    temp_dir = os.path.join("temp_local", f"session_{session_id}")
    os.makedirs(temp_dir, exist_ok=True)
    delete_later(temp_dir, delay=600)
    return temp_dir


# Note: Models are loaded lazily inside @spaces.GPU decorated functions
# This is required for HF Spaces ZeroGPU compatibility
logger.info("Models will be loaded lazily when GPU is available")
sys.stdout.flush()

logger.info("Setting up Gradio static paths...")
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
logger.info("✅ Static paths configured")
sys.stdout.flush()


def generate_camera_trajectory(num_frames: int, movement_type: str,
                               base_intrinsics: np.ndarray,
                               scene_scale: float = 1.0) -> tuple:
    """
    Generate camera extrinsics for different movement types.

    Returns:
        extrinsics: (T, 4, 4) relative camera transforms; during rendering each
        one is composed with the first frame's camera-to-world pose
        (new_c2w = base_c2w @ extrinsics[t])
    """
    # Movement speed (adjust based on scene scale)
    speed = scene_scale * 0.02

    extrinsics = np.zeros((num_frames, 4, 4), dtype=np.float32)

    for t in range(num_frames):
        # Start with identity matrix
        ext = np.eye(4, dtype=np.float32)
        progress = t / max(num_frames - 1, 1)

        if movement_type == "static":
            pass  # Keep identity
        elif movement_type == "move_forward":
            # Move along -Z (forward in OpenGL convention)
            ext[2, 3] = -speed * t
        elif movement_type == "move_backward":
            ext[2, 3] = speed * t  # Move along +Z
        elif movement_type == "move_left":
            ext[0, 3] = -speed * t  # Move along -X
        elif movement_type == "move_right":
            ext[0, 3] = speed * t  # Move along +X
        elif movement_type == "move_up":
            ext[1, 3] = -speed * t  # Move along -Y (up in OpenGL)
        elif movement_type == "move_down":
            ext[1, 3] = speed * t  # Move along +Y

        extrinsics[t] = ext

    return extrinsics


def render_from_pointcloud(rgb_frames: np.ndarray, depth_frames: np.ndarray,
                           intrinsics: np.ndarray, original_extrinsics: np.ndarray,
                           new_extrinsics: np.ndarray, output_path: str,
                           fps: int = 24, generate_ttm_inputs: bool = False) -> dict:
    """
    Render video from point cloud with new camera trajectory.
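    Each pixel is unprojected with its depth, the resulting 3D points are
    transformed into the new camera frame and forward-splatted with a z-buffer;
    holes are then filled by iterative nearest-neighbor dilation.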
Args: rgb_frames: (T, H, W, 3) RGB frames depth_frames: (T, H, W) depth maps intrinsics: (T, 3, 3) camera intrinsics original_extrinsics: (T, 4, 4) original camera extrinsics (world-to-camera) new_extrinsics: (T, 4, 4) new camera extrinsics for rendering output_path: path to save rendered video fps: output video fps generate_ttm_inputs: if True, also generate motion_signal.mp4 and mask.mp4 for TTM Returns: dict with paths: {'rendered': path, 'motion_signal': path or None, 'mask': path or None} """ T, H, W, _ = rgb_frames.shape # Setup video writers fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (W, H)) # TTM outputs: motion_signal (warped with NN inpainting) and mask (valid pixels before inpainting) motion_signal_path = None mask_path = None out_motion_signal = None out_mask = None if generate_ttm_inputs: base_dir = os.path.dirname(output_path) motion_signal_path = os.path.join(base_dir, "motion_signal.mp4") mask_path = os.path.join(base_dir, "mask.mp4") out_motion_signal = cv2.VideoWriter( motion_signal_path, fourcc, fps, (W, H)) out_mask = cv2.VideoWriter(mask_path, fourcc, fps, (W, H)) # Create meshgrid for pixel coordinates u, v = np.meshgrid(np.arange(W), np.arange(H)) ones = np.ones_like(u) for t in range(T): # Get current frame data rgb = rgb_frames[t] depth = depth_frames[t] K = intrinsics[t] # Original camera pose (camera-to-world) orig_c2w = np.linalg.inv(original_extrinsics[t]) # New camera pose (camera-to-world for the new viewpoint) # Apply the new extrinsics relative to the first frame if t == 0: base_c2w = orig_c2w.copy() # New camera is: base_c2w @ new_extrinsics[t] new_c2w = base_c2w @ new_extrinsics[t] new_w2c = np.linalg.inv(new_c2w) # Unproject pixels to 3D points K_inv = np.linalg.inv(K) # Pixel coordinates to normalized camera coordinates pixels = np.stack([u, v, ones], axis=-1).reshape(-1, 3) # (H*W, 3) rays_cam = (K_inv @ pixels.T).T # (H*W, 3) # Scale by depth to get 3D points in original camera frame depth_flat = depth.reshape(-1, 1) points_cam = rays_cam * depth_flat # (H*W, 3) # Transform to world coordinates points_world = (orig_c2w[:3, :3] @ points_cam.T).T + orig_c2w[:3, 3] # Transform to new camera coordinates points_new_cam = (new_w2c[:3, :3] @ points_world.T).T + new_w2c[:3, 3] # Project to new image points_proj = (K @ points_new_cam.T).T # Get pixel coordinates z = points_proj[:, 2:3] z = np.clip(z, 1e-6, None) # Avoid division by zero uv_new = points_proj[:, :2] / z # Create output image using forward warping with z-buffer rendered = np.zeros((H, W, 3), dtype=np.uint8) z_buffer = np.full((H, W), np.inf, dtype=np.float32) colors = rgb.reshape(-1, 3) depths_new = points_new_cam[:, 2] for i in range(len(uv_new)): uu, vv = int(round(uv_new[i, 0])), int(round(uv_new[i, 1])) if 0 <= uu < W and 0 <= vv < H and depths_new[i] > 0: if depths_new[i] < z_buffer[vv, uu]: z_buffer[vv, uu] = depths_new[i] rendered[vv, uu] = colors[i] # Create valid pixel mask BEFORE hole filling (for TTM) # Valid pixels are those that received projected colors valid_mask = (rendered.sum(axis=-1) > 0).astype(np.uint8) * 255 # Nearest-neighbor hole filling using dilation # This is the inpainting method described in TTM: "Missing regions are inpainted by nearest-neighbor color assignment" motion_signal_frame = rendered.copy() hole_mask = (motion_signal_frame.sum(axis=-1) == 0).astype(np.uint8) if hole_mask.sum() > 0: kernel = np.ones((3, 3), np.uint8) # Iteratively dilate to fill holes with nearest neighbor colors max_iterations = max(H, W) # 
Ensure all holes can be filled for _ in range(max_iterations): if hole_mask.sum() == 0: break dilated = cv2.dilate(motion_signal_frame, kernel, iterations=1) motion_signal_frame = np.where( hole_mask[:, :, None] > 0, dilated, motion_signal_frame) hole_mask = (motion_signal_frame.sum( axis=-1) == 0).astype(np.uint8) # Write TTM outputs if enabled if generate_ttm_inputs: # Motion signal: warped frame with NN inpainting motion_signal_bgr = cv2.cvtColor( motion_signal_frame, cv2.COLOR_RGB2BGR) out_motion_signal.write(motion_signal_bgr) # Mask: binary mask of valid (projected) pixels - white where valid, black where holes mask_frame = np.stack( [valid_mask, valid_mask, valid_mask], axis=-1) out_mask.write(mask_frame) # For the rendered output, use the same inpainted result rendered_bgr = cv2.cvtColor(motion_signal_frame, cv2.COLOR_RGB2BGR) out.write(rendered_bgr) out.release() if generate_ttm_inputs: out_motion_signal.release() out_mask.release() return { 'rendered': output_path, 'motion_signal': motion_signal_path, 'mask': mask_path } @spaces.GPU(duration=180) def run_spatial_tracker(video_tensor: torch.Tensor): """ GPU-intensive spatial tracking function. Args: video_tensor: Preprocessed video tensor (T, C, H, W) Returns: Dictionary containing tracking results """ global vggt4track_model, tracker_model logger.info("run_spatial_tracker: Starting GPU execution...") sys.stdout.flush() # Load models if not already loaded (lazy loading for HF Spaces) load_models() logger.info("run_spatial_tracker: Preprocessing video input...") sys.stdout.flush() # Run VGGT to get depth and camera poses video_input = preprocess_image(video_tensor)[None].cuda() logger.info("run_spatial_tracker: Running VGGT inference...") sys.stdout.flush() with torch.no_grad(): with torch.cuda.amp.autocast(dtype=torch.bfloat16): predictions = vggt4track_model(video_input / 255) extrinsic = predictions["poses_pred"] intrinsic = predictions["intrs"] depth_map = predictions["points_map"][..., 2] depth_conf = predictions["unc_metric"] logger.info("run_spatial_tracker: VGGT inference complete") sys.stdout.flush() depth_tensor = depth_map.squeeze().cpu().numpy() extrs = extrinsic.squeeze().cpu().numpy() intrs = intrinsic.squeeze().cpu().numpy() video_tensor_gpu = video_input.squeeze() unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5 # Setup tracker logger.info("run_spatial_tracker: Setting up tracker...") sys.stdout.flush() tracker_model.spatrack.track_num = 512 tracker_model.to("cuda") # Get grid points for tracking frame_H, frame_W = video_tensor_gpu.shape[2:] grid_pts = get_points_on_a_grid(30, (frame_H, frame_W), device="cpu") query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[ 0].numpy() logger.info("run_spatial_tracker: Running 3D tracker...") sys.stdout.flush() # Run tracker with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): ( c2w_traj, intrs_out, point_map, conf_depth, track3d_pred, track2d_pred, vis_pred, conf_pred, video_out ) = tracker_model.forward( video_tensor_gpu, depth=depth_tensor, intrs=intrs, extrs=extrs, queries=query_xyt, fps=1, full_point=False, iters_track=4, query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric, support_frame=len(video_tensor_gpu)-1, replace_ratio=0.2 ) # Resize outputs for rendering max_size = 384 h, w = video_out.shape[2:] scale = min(max_size / h, max_size / w) if scale < 1: new_h, new_w = int(h * scale), int(w * scale) video_out = T.Resize((new_h, new_w))(video_out) point_map = T.Resize((new_h, new_w))(point_map) conf_depth = 
T.Resize((new_h, new_w))(conf_depth) intrs_out[:, :2, :] = intrs_out[:, :2, :] * scale logger.info("run_spatial_tracker: Moving results to CPU...") sys.stdout.flush() # Move results to CPU and return result = { 'video_out': video_out.cpu(), 'point_map': point_map.cpu(), 'conf_depth': conf_depth.cpu(), 'intrs_out': intrs_out.cpu(), 'c2w_traj': c2w_traj.cpu(), } logger.info("run_spatial_tracker: Complete!") sys.stdout.flush() return result def process_video(video_path: str, camera_movement: str, generate_ttm: bool = True, progress=gr.Progress()): """Main processing function Args: video_path: Path to input video camera_movement: Type of camera movement generate_ttm: If True, generate TTM-compatible outputs (motion_signal.mp4, mask.mp4, first_frame.png) progress: Gradio progress tracker """ if video_path is None: return None, None, None, None, "❌ Please upload a video first" progress(0, desc="Initializing...") # Create temp directory temp_dir = create_user_temp_dir() out_dir = os.path.join(temp_dir, "results") os.makedirs(out_dir, exist_ok=True) try: # Load video progress(0.1, desc="Loading video...") video_reader = decord.VideoReader(video_path) video_tensor = torch.from_numpy( video_reader.get_batch(range(len(video_reader))).asnumpy() ).permute(0, 3, 1, 2).float() # Subsample frames if too many fps_skip = max(1, len(video_tensor) // MAX_FRAMES) video_tensor = video_tensor[::fps_skip][:MAX_FRAMES] # Resize to have minimum side 336 h, w = video_tensor.shape[2:] scale = 336 / min(h, w) if scale < 1: new_h, new_w = int(h * scale) // 2 * 2, int(w * scale) // 2 * 2 video_tensor = T.Resize((new_h, new_w))(video_tensor) progress(0.2, desc="Estimating depth and camera poses...") # Run GPU-intensive spatial tracking progress(0.4, desc="Running 3D tracking...") tracking_results = run_spatial_tracker(video_tensor) progress(0.6, desc="Preparing point cloud...") # Extract results from tracking video_out = tracking_results['video_out'] point_map = tracking_results['point_map'] conf_depth = tracking_results['conf_depth'] intrs_out = tracking_results['intrs_out'] c2w_traj = tracking_results['c2w_traj'] # Get RGB frames and depth rgb_frames = rearrange( video_out.numpy(), "T C H W -> T H W C").astype(np.uint8) depth_frames = point_map[:, 2].numpy() depth_conf_np = conf_depth.numpy() # Mask out unreliable depth depth_frames[depth_conf_np < 0.5] = 0 # Get camera parameters intrs_np = intrs_out.numpy() extrs_np = torch.inverse(c2w_traj).numpy() # world-to-camera progress( 0.7, desc=f"Generating {camera_movement} camera trajectory...") # Calculate scene scale from depth valid_depth = depth_frames[depth_frames > 0] scene_scale = np.median(valid_depth) if len(valid_depth) > 0 else 1.0 # Generate new camera trajectory num_frames = len(rgb_frames) new_extrinsics = generate_camera_trajectory( num_frames, camera_movement, intrs_np, scene_scale ) progress(0.8, desc="Rendering video from new viewpoint...") # Render video (CPU-based, no GPU needed) output_video_path = os.path.join(out_dir, "rendered_video.mp4") render_results = render_from_pointcloud( rgb_frames, depth_frames, intrs_np, extrs_np, new_extrinsics, output_video_path, fps=OUTPUT_FPS, generate_ttm_inputs=generate_ttm ) # Save first frame for TTM first_frame_path = None motion_signal_path = None mask_path = None if generate_ttm: first_frame_path = os.path.join(out_dir, "first_frame.png") # Save original first frame (before warping) as PNG first_frame_rgb = rgb_frames[0] first_frame_bgr = cv2.cvtColor(first_frame_rgb, cv2.COLOR_RGB2BGR) 
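            # The three TTM conditioning inputs written below are consumed by Step 2:
            #   first_frame.png   - unwarped first frame, used as the I2V conditioning image
            #   motion_signal.mp4 - point-cloud render along the new trajectory (NN-inpainted)
            #   mask.mp4          - per-frame validity mask of the warp (white = projected pixel)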
cv2.imwrite(first_frame_path, first_frame_bgr) motion_signal_path = render_results['motion_signal'] mask_path = render_results['mask'] progress(1.0, desc="Done!") status_msg = f"✅ Video rendered successfully with '{camera_movement}' camera movement!" if generate_ttm: status_msg += "\n\n📁 **TTM outputs generated:**\n" status_msg += f"- `first_frame.png`: Input frame for TTM\n" status_msg += f"- `motion_signal.mp4`: Warped video with NN inpainting\n" status_msg += f"- `mask.mp4`: Valid pixel mask (white=valid, black=hole)" return render_results['rendered'], motion_signal_path, mask_path, first_frame_path, status_msg except Exception as e: logger.error(f"Error processing video: {e}") import traceback traceback.print_exc() return None, None, None, None, f"❌ Error: {str(e)}" # TTM CogVideoX Pipeline Helper Classes and Functions class CogVideoXTTMHelper: """Helper class for TTM-style video generation using CogVideoX pipeline.""" def __init__(self, pipeline): self.pipeline = pipeline self.vae = pipeline.vae self.transformer = pipeline.transformer self.scheduler = pipeline.scheduler self.vae_scale_factor_spatial = 2 ** ( len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio self.vae_scaling_factor_image = self.vae.config.scaling_factor self.video_processor = pipeline.video_processor @torch.no_grad() def encode_frames(self, frames: torch.Tensor) -> torch.Tensor: """Encode video frames into latent space. Input shape (B, C, F, H, W), expected range [-1, 1].""" latents = self.vae.encode(frames)[0].sample() latents = latents * self.vae_scaling_factor_image # (B, C, F, H, W) -> (B, F, C, H, W) return latents.permute(0, 2, 1, 3, 4).contiguous() def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor: """Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W'].""" k = self.vae_scale_factor_temporal mask0 = mask[0:1] mask1 = mask[1::k] sampled = torch.cat([mask0, mask1], dim=0) pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0) s = self.vae_scale_factor_spatial H_latent = pooled.shape[-2] // s W_latent = pooled.shape[-1] // s pooled = F.interpolate(pooled, size=( pooled.shape[2], H_latent, W_latent), mode="nearest") latent_mask = pooled.permute(0, 2, 1, 3, 4) return latent_mask # TTM Wan Pipeline Helper Class class WanTTMHelper: """Helper class for TTM-style video generation using Wan pipeline.""" def __init__(self, pipeline): self.pipeline = pipeline self.vae = pipeline.vae self.transformer = pipeline.transformer self.scheduler = pipeline.scheduler self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial self.video_processor = pipeline.video_processor def convert_rgb_mask_to_latent_mask(self, mask: torch.Tensor) -> torch.Tensor: """Convert a per-frame mask [T, 1, H, W] to latent resolution [1, T_latent, 1, H', W'].""" k = self.vae_scale_factor_temporal mask0 = mask[0:1] mask1 = mask[1::k] sampled = torch.cat([mask0, mask1], dim=0) pooled = sampled.permute(1, 0, 2, 3).unsqueeze(0) s = self.vae_scale_factor_spatial H_latent = pooled.shape[-2] // s W_latent = pooled.shape[-1] // s pooled = F.interpolate(pooled, size=( pooled.shape[2], H_latent, W_latent), mode="nearest") latent_mask = pooled.permute(0, 2, 1, 3, 4) return latent_mask def compute_hw_from_area(image_height: int, image_width: int, max_area: int, mod_value: int) -> tuple: """Compute (height, width) with proper aspect ratio and rounding.""" aspect_ratio 
= image_height / image_width height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value return int(height), int(width) @spaces.GPU(duration=300) def run_ttm_cog_generation( first_frame_path: str, motion_signal_path: str, mask_path: str, prompt: str, tweak_index: int = 4, tstrong_index: int = 9, num_frames: int = 49, num_inference_steps: int = 50, guidance_scale: float = 6.0, seed: int = 0, progress=gr.Progress() ): """ Run TTM-style video generation using CogVideoX pipeline. Uses the generated motion signal and mask to guide video generation. """ if not TTM_COG_AVAILABLE: return None, "❌ CogVideoX TTM is not available. Please install diffusers package." if first_frame_path is None or motion_signal_path is None or mask_path is None: return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)" progress(0, desc="Loading CogVideoX TTM pipeline...") try: # Get or load the pipeline pipe = get_ttm_cog_pipeline() if pipe is None: return None, "❌ Failed to load CogVideoX TTM pipeline" pipe = pipe.to("cuda") # Create helper ttm_helper = CogVideoXTTMHelper(pipe) progress(0.1, desc="Loading inputs...") # Load first frame image = load_image(first_frame_path) # Get dimensions height = pipe.transformer.config.sample_height * \ ttm_helper.vae_scale_factor_spatial width = pipe.transformer.config.sample_width * \ ttm_helper.vae_scale_factor_spatial device = "cuda" generator = torch.Generator(device=device).manual_seed(seed) progress(0.15, desc="Encoding prompt...") # Encode prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds, negative_prompt_embeds = pipe.encode_prompt( prompt=prompt, negative_prompt="", do_classifier_free_guidance=do_classifier_free_guidance, num_videos_per_prompt=1, max_sequence_length=226, device=device, ) if do_classifier_free_guidance: prompt_embeds = torch.cat( [negative_prompt_embeds, prompt_embeds], dim=0) progress(0.2, desc="Preparing latents...") # Prepare timesteps pipe.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = pipe.scheduler.timesteps # Prepare latents latent_frames = ( num_frames - 1) // ttm_helper.vae_scale_factor_temporal + 1 # Handle padding for CogVideoX 1.5 patch_size_t = pipe.transformer.config.patch_size_t additional_frames = 0 if patch_size_t is not None and latent_frames % patch_size_t != 0: additional_frames = patch_size_t - latent_frames % patch_size_t num_frames += additional_frames * ttm_helper.vae_scale_factor_temporal # Preprocess image image_tensor = ttm_helper.video_processor.preprocess(image, height=height, width=width).to( device, dtype=prompt_embeds.dtype ) latent_channels = pipe.transformer.config.in_channels // 2 latents, image_latents = pipe.prepare_latents( image_tensor, 1, # batch_size latent_channels, num_frames, height, width, prompt_embeds.dtype, device, generator, None, ) progress(0.3, desc="Loading motion signal and mask...") # Load motion signal video ref_vid = load_video_to_tensor(motion_signal_path).to(device=device) refB, refC, refT, refH, refW = ref_vid.shape ref_vid = F.interpolate( ref_vid.permute(0, 2, 1, 3, 4).reshape( refB*refT, refC, refH, refW), size=(height, width), mode="bicubic", align_corners=True, ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4) ref_vid = ttm_helper.video_processor.normalize( ref_vid.to(dtype=pipe.vae.dtype)) ref_latents = ttm_helper.encode_frames(ref_vid).float().detach() # Load mask video ref_mask = 
load_video_to_tensor(mask_path).to(device=device) mB, mC, mT, mH, mW = ref_mask.shape ref_mask = F.interpolate( ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW), size=(height, width), mode="nearest", ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4) ref_mask = ref_mask[0].permute(1, 0, 2, 3).contiguous() if len(ref_mask.shape) == 4: ref_mask = ref_mask.unsqueeze(0) ref_mask = ref_mask[0, :, :1].contiguous() ref_mask = (ref_mask > 0.5).float().max(dim=1, keepdim=True)[0] motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask(ref_mask) background_mask = 1.0 - motion_mask progress(0.35, desc="Initializing TTM denoising...") # Initialize with noisy reference latents at tweak timestep if tweak_index >= 0: tweak = timesteps[tweak_index] fixed_noise = randn_tensor( ref_latents.shape, generator=generator, device=ref_latents.device, dtype=ref_latents.dtype, ) noisy_latents = pipe.scheduler.add_noise( ref_latents, fixed_noise, tweak.long()) latents = noisy_latents.to( dtype=latents.dtype, device=latents.device) else: fixed_noise = randn_tensor( ref_latents.shape, generator=generator, device=ref_latents.device, dtype=ref_latents.dtype, ) tweak_index = 0 # Prepare extra step kwargs extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, 0.0) # Create rotary embeddings if required image_rotary_emb = ( pipe._prepare_rotary_positional_embeddings( height, width, latents.size(1), device) if pipe.transformer.config.use_rotary_positional_embeddings else None ) # Create ofs embeddings if required ofs_emb = None if pipe.transformer.config.ofs_embed_dim is None else latents.new_full( (1,), fill_value=2.0) progress(0.4, desc="Running TTM denoising loop...") # Denoising loop total_steps = len(timesteps) - tweak_index old_pred_original_sample = None for i, t in enumerate(timesteps[tweak_index:]): step_progress = 0.4 + 0.5 * (i / total_steps) progress(step_progress, desc=f"Denoising step {i+1}/{total_steps}...") latent_model_input = torch.cat( [latents] * 2) if do_classifier_free_guidance else latents latent_model_input = pipe.scheduler.scale_model_input( latent_model_input, t) latent_image_input = torch.cat( [image_latents] * 2) if do_classifier_free_guidance else image_latents latent_model_input = torch.cat( [latent_model_input, latent_image_input], dim=2) timestep = t.expand(latent_model_input.shape[0]) # Predict noise noise_pred = pipe.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, ofs=ofs_emb, image_rotary_emb=image_rotary_emb, return_dict=False, )[0] noise_pred = noise_pred.float() # Perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * \ (noise_pred_text - noise_pred_uncond) # Compute previous noisy sample if not isinstance(pipe.scheduler, CogVideoXDPMScheduler): latents, old_pred_original_sample = pipe.scheduler.step( noise_pred, t, latents, **extra_step_kwargs, return_dict=False ) else: latents, old_pred_original_sample = pipe.scheduler.step( noise_pred, old_pred_original_sample, t, timesteps[i - 1] if i > 0 else None, latents, **extra_step_kwargs, return_dict=False, ) # TTM: In between tweak and tstrong, replace mask with noisy reference latents in_between_tweak_tstrong = (i + tweak_index) < tstrong_index if in_between_tweak_tstrong: if i + tweak_index + 1 < len(timesteps): prev_t = timesteps[i + tweak_index + 1] noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to( dtype=latents.dtype, 
device=latents.device ) latents = latents * background_mask + noisy_latents * motion_mask else: latents = latents * background_mask + ref_latents * motion_mask latents = latents.to(prompt_embeds.dtype) progress(0.9, desc="Decoding video...") # Decode latents latents = latents[:, additional_frames:] frames = pipe.decode_latents(latents) video = ttm_helper.video_processor.postprocess_video( video=frames, output_type="pil") progress(0.95, desc="Saving video...") # Save video temp_dir = create_user_temp_dir() output_path = os.path.join(temp_dir, "ttm_output.mp4") export_to_video(video[0], output_path, fps=8) progress(1.0, desc="Done!") return output_path, f"✅ CogVideoX TTM video generated successfully!\n\n**Parameters:**\n- Model: CogVideoX-5B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}" except Exception as e: logger.error(f"Error in CogVideoX TTM generation: {e}") import traceback traceback.print_exc() return None, f"❌ Error: {str(e)}" @spaces.GPU(duration=300) def run_ttm_wan_generation( first_frame_path: str, motion_signal_path: str, mask_path: str, prompt: str, negative_prompt: str = "", tweak_index: int = 3, tstrong_index: int = 7, num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 3.5, seed: int = 0, progress=gr.Progress() ): """ Run TTM-style video generation using Wan 2.2 pipeline. This is the recommended model for TTM as it produces higher-quality results. """ if not TTM_WAN_AVAILABLE: return None, "❌ Wan TTM is not available. Please install diffusers with Wan support." if first_frame_path is None or motion_signal_path is None or mask_path is None: return None, "❌ Please generate TTM inputs first (first_frame, motion_signal, mask)" progress(0, desc="Loading Wan 2.2 TTM pipeline...") try: # Get or load the pipeline pipe = get_ttm_wan_pipeline() if pipe is None: return None, "❌ Failed to load Wan TTM pipeline" pipe = pipe.to("cuda") # Create helper ttm_helper = WanTTMHelper(pipe) progress(0.1, desc="Loading inputs...") # Load first frame image = load_image(first_frame_path) # Get dimensions - compute based on image aspect ratio max_area = 480 * 832 mod_value = ttm_helper.vae_scale_factor_spatial * \ pipe.transformer.config.patch_size[1] height, width = compute_hw_from_area( image.height, image.width, max_area, mod_value) image = image.resize((width, height)) device = "cuda" gen_device = device if device.startswith("cuda") else "cpu" generator = torch.Generator(device=gen_device).manual_seed(seed) progress(0.15, desc="Encoding prompt...") # Encode prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds, negative_prompt_embeds = pipe.encode_prompt( prompt=prompt, negative_prompt=negative_prompt if negative_prompt else None, do_classifier_free_guidance=do_classifier_free_guidance, num_videos_per_prompt=1, max_sequence_length=512, device=device, ) # Get transformer dtype transformer_dtype = pipe.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to( transformer_dtype) # Encode image embedding if transformer supports it image_embeds = None if pipe.transformer.config.image_dim is not None: image_embeds = pipe.encode_image(image, device) image_embeds = image_embeds.repeat(1, 1, 1) image_embeds = image_embeds.to(transformer_dtype) progress(0.2, desc="Preparing latents...") # Prepare timesteps pipe.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = 
pipe.scheduler.timesteps # Adjust num_frames to be valid for VAE if num_frames % ttm_helper.vae_scale_factor_temporal != 1: num_frames = num_frames // ttm_helper.vae_scale_factor_temporal * \ ttm_helper.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) # Prepare latent variables num_channels_latents = pipe.vae.config.z_dim image_tensor = ttm_helper.video_processor.preprocess( image, height=height, width=width).to(device, dtype=torch.float32) latents_outputs = pipe.prepare_latents( image_tensor, 1, # batch_size num_channels_latents, height, width, num_frames, torch.float32, device, generator, None, None, # last_image ) if hasattr(pipe, 'config') and pipe.config.expand_timesteps: latents, condition, first_frame_mask = latents_outputs else: latents, condition = latents_outputs first_frame_mask = None progress(0.3, desc="Loading motion signal and mask...") # Load motion signal video ref_vid = load_video_to_tensor(motion_signal_path).to(device=device) refB, refC, refT, refH, refW = ref_vid.shape ref_vid = F.interpolate( ref_vid.permute(0, 2, 1, 3, 4).reshape( refB*refT, refC, refH, refW), size=(height, width), mode="bicubic", align_corners=True, ).reshape(refB, refT, refC, height, width).permute(0, 2, 1, 3, 4) ref_vid = ttm_helper.video_processor.normalize( ref_vid.to(dtype=pipe.vae.dtype)) ref_latents = retrieve_latents( pipe.vae.encode(ref_vid), sample_mode="argmax") # Normalize latents latents_mean = torch.tensor(pipe.vae.config.latents_mean).view( 1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype) latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view( 1, pipe.vae.config.z_dim, 1, 1, 1).to(ref_latents.device, ref_latents.dtype) ref_latents = (ref_latents - latents_mean) * latents_std # Load mask video ref_mask = load_video_to_tensor(mask_path).to(device=device) mB, mC, mT, mH, mW = ref_mask.shape ref_mask = F.interpolate( ref_mask.permute(0, 2, 1, 3, 4).reshape(mB*mT, mC, mH, mW), size=(height, width), mode="nearest", ).reshape(mB, mT, mC, height, width).permute(0, 2, 1, 3, 4) mask_tc_hw = ref_mask[0].permute(1, 0, 2, 3).contiguous() # Align time dimension if mask_tc_hw.shape[0] > num_frames: mask_tc_hw = mask_tc_hw[:num_frames] elif mask_tc_hw.shape[0] < num_frames: return None, f"❌ num_frames ({num_frames}) > mask frames ({mask_tc_hw.shape[0]}). Please use more mask frames." 
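        # The RGB-space mask is binarized, collapsed to a single channel, and mapped to
        # latent resolution below. During denoising, `motion_mask` marks latent cells that
        # follow the warped motion signal (the dual-clock "strong" region), while
        # `background_mask = 1 - motion_mask` is re-synthesized freely from tweak_index on.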
# Reduce channels if needed if mask_tc_hw.shape[1] > 1: mask_t1_hw = (mask_tc_hw > 0.5).any(dim=1, keepdim=True).float() else: mask_t1_hw = (mask_tc_hw > 0.5).float() motion_mask = ttm_helper.convert_rgb_mask_to_latent_mask( mask_t1_hw).permute(0, 2, 1, 3, 4).contiguous() background_mask = 1.0 - motion_mask progress(0.35, desc="Initializing TTM denoising...") # Initialize with noisy reference latents at tweak timestep if tweak_index >= 0 and tweak_index < len(timesteps): tweak = timesteps[tweak_index] fixed_noise = randn_tensor( ref_latents.shape, generator=generator, device=ref_latents.device, dtype=ref_latents.dtype, ) tweak_t = torch.as_tensor( tweak, device=ref_latents.device, dtype=torch.long).view(1) noisy_latents = pipe.scheduler.add_noise( ref_latents, fixed_noise, tweak_t.long()) latents = noisy_latents.to( dtype=latents.dtype, device=latents.device) else: fixed_noise = randn_tensor( ref_latents.shape, generator=generator, device=ref_latents.device, dtype=ref_latents.dtype, ) tweak_index = 0 progress(0.4, desc="Running TTM denoising loop...") # Denoising loop total_steps = len(timesteps) - tweak_index for i, t in enumerate(timesteps[tweak_index:]): step_progress = 0.4 + 0.5 * (i / total_steps) progress(step_progress, desc=f"Denoising step {i+1}/{total_steps}...") # Prepare model input if first_frame_mask is not None: latent_model_input = (1 - first_frame_mask) * \ condition + first_frame_mask * latents latent_model_input = latent_model_input.to(transformer_dtype) temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) else: latent_model_input = torch.cat( [latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) # Predict noise (conditional) noise_pred = pipe.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=image_embeds, return_dict=False, )[0] # CFG if do_classifier_free_guidance: noise_uncond = pipe.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, encoder_hidden_states_image=image_embeds, return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * \ (noise_pred - noise_uncond) # Scheduler step latents = pipe.scheduler.step( noise_pred, t, latents, return_dict=False)[0] # TTM: In between tweak and tstrong, replace mask with noisy reference latents in_between_tweak_tstrong = (i + tweak_index) < tstrong_index if in_between_tweak_tstrong: if i + tweak_index + 1 < len(timesteps): prev_t = timesteps[i + tweak_index + 1] prev_t = torch.as_tensor( prev_t, device=ref_latents.device, dtype=torch.long).view(1) noisy_latents = pipe.scheduler.add_noise(ref_latents, fixed_noise, prev_t.long()).to( dtype=latents.dtype, device=latents.device ) latents = latents * background_mask + noisy_latents * motion_mask else: latents = latents * background_mask + \ ref_latents.to(dtype=latents.dtype, device=latents.device) * motion_mask progress(0.9, desc="Decoding video...") # Apply first frame mask if used if first_frame_mask is not None: latents = (1 - first_frame_mask) * condition + \ first_frame_mask * latents # Decode latents latents = latents.to(pipe.vae.dtype) latents_mean = torch.tensor(pipe.vae.config.latents_mean).view( 1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view( 1, pipe.vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents = latents / 
latents_std + latents_mean video = pipe.vae.decode(latents, return_dict=False)[0] video = ttm_helper.video_processor.postprocess_video( video, output_type="pil") progress(0.95, desc="Saving video...") # Save video temp_dir = create_user_temp_dir() output_path = os.path.join(temp_dir, "ttm_wan_output.mp4") export_to_video(video[0], output_path, fps=16) progress(1.0, desc="Done!") return output_path, f"✅ Wan 2.2 TTM video generated successfully!\n\n**Parameters:**\n- Model: Wan2.2-14B\n- tweak_index: {tweak_index}\n- tstrong_index: {tstrong_index}\n- guidance_scale: {guidance_scale}\n- seed: {seed}" except Exception as e: logger.error(f"Error in Wan TTM generation: {e}") import traceback traceback.print_exc() return None, f"❌ Error: {str(e)}" def run_ttm_generation( first_frame_path: str, motion_signal_path: str, mask_path: str, prompt: str, negative_prompt: str, model_choice: str, tweak_index: int, tstrong_index: int, num_frames: int, num_inference_steps: int, guidance_scale: float, seed: int, progress=gr.Progress() ): """ Router function that calls the appropriate TTM generation based on model choice. """ if "Wan" in model_choice: return run_ttm_wan_generation( first_frame_path=first_frame_path, motion_signal_path=motion_signal_path, mask_path=mask_path, prompt=prompt, negative_prompt=negative_prompt, tweak_index=tweak_index, tstrong_index=tstrong_index, num_frames=num_frames, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, seed=seed, progress=progress, ) else: return run_ttm_cog_generation( first_frame_path=first_frame_path, motion_signal_path=motion_signal_path, mask_path=mask_path, prompt=prompt, tweak_index=tweak_index, tstrong_index=tstrong_index, num_frames=num_frames, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, seed=seed, progress=progress, ) # Create Gradio interface logger.info("🎨 Creating Gradio interface...") sys.stdout.flush() with gr.Blocks( theme=gr.themes.Soft(), title="🎬 Video to Point Cloud Renderer", css=""" .gradio-container { max-width: 1400px !important; margin: auto !important; } """ ) as demo: gr.Markdown(""" # 🎬 Video to Point Cloud Renderer + TTM Video Generation Upload a video to generate a 3D point cloud, render it from a new camera perspective, and optionally run **Time-to-Move (TTM)** for motion-controlled video generation. **Workflow:** 1. **Step 1**: Upload a video and select camera movement → Generate motion signal & mask 2. 
**Step 2**: (Optional) Run TTM to generate a high-quality video with the motion signal **TTM (Time-to-Move)** uses dual-clock denoising to guide video generation using: - `first_frame.png`: Starting frame - `motion_signal.mp4`: Warped video showing desired motion - `mask.mp4`: Binary mask for motion regions """) # State to store paths for TTM first_frame_state = gr.State(None) motion_signal_state = gr.State(None) mask_state = gr.State(None) with gr.Tabs(): with gr.Tab("📥 Step 1: Generate Motion Signal"): with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 📥 Input") video_input = gr.Video( label="Upload Video", format="mp4", height=300 ) camera_movement = gr.Dropdown( choices=CAMERA_MOVEMENTS, value="static", label="🎥 Camera Movement", info="Select how the camera should move in the rendered video" ) generate_ttm = gr.Checkbox( label="🎯 Generate TTM Inputs", value=True, info="Generate motion_signal.mp4 and mask.mp4 for Time-to-Move" ) generate_btn = gr.Button( "🚀 Generate Motion Signal", variant="primary", size="lg") with gr.Column(scale=1): gr.Markdown("### 📤 Rendered Output") output_video = gr.Video( label="Rendered Video", height=250 ) first_frame_output = gr.Image( label="First Frame (first_frame.png)", height=150 ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 🎯 TTM: Motion Signal") motion_signal_output = gr.Video( label="Motion Signal Video (motion_signal.mp4)", height=250 ) with gr.Column(scale=1): gr.Markdown("### 🎭 TTM: Mask") mask_output = gr.Video( label="Mask Video (mask.mp4)", height=250 ) status_text = gr.Markdown("Ready to process...") with gr.Tab("🎬 Step 2: TTM Video Generation"): cog_available = "✅" if TTM_COG_AVAILABLE else "❌" wan_available = "✅" if TTM_WAN_AVAILABLE else "❌" gr.Markdown(f""" ### 🎬 Time-to-Move (TTM) Video Generation **Model Availability:** - {cog_available} CogVideoX-5B-I2V - {wan_available} Wan 2.2-14B (Recommended - higher quality) **TTM Parameters:** - **tweak_index**: When denoising starts *outside* the mask (lower = more dynamic background) - **tstrong_index**: When denoising starts *inside* the mask (higher = more constrained motion) **Recommended values:** - CogVideoX - Cut-and-Drag: `tweak_index=4`, `tstrong_index=9` - CogVideoX - Camera control: `tweak_index=3`, `tstrong_index=7` - **Wan 2.2 (Recommended)**: `tweak_index=3`, `tstrong_index=7` """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### ⚙️ TTM Settings") ttm_model_choice = gr.Dropdown( choices=TTM_MODELS if TTM_MODELS else ["No TTM models available"], value=TTM_MODELS[1] if TTM_WAN_AVAILABLE else (TTM_MODELS[0] if TTM_MODELS else None), label="Model", info="Wan 2.2 is recommended for higher quality" ) ttm_prompt = gr.Textbox( label="Prompt", placeholder="Describe the video content...", value="A high quality video, smooth motion, natural lighting", lines=2 ) ttm_negative_prompt = gr.Textbox( label="Negative Prompt (Wan only)", placeholder="Things to avoid in the video...", value="", lines=1, visible=True ) with gr.Row(): ttm_tweak_index = gr.Slider( minimum=0, maximum=20, value=3, step=1, label="tweak_index", info="When background denoising starts" ) ttm_tstrong_index = gr.Slider( minimum=0, maximum=30, value=7, step=1, label="tstrong_index", info="When mask region denoising starts" ) with gr.Row(): ttm_num_frames = gr.Slider( minimum=17, maximum=81, value=49, step=4, label="Number of Frames" ) ttm_guidance_scale = gr.Slider( minimum=1.0, maximum=15.0, value=3.5, step=0.5, label="Guidance Scale" ) with gr.Row(): ttm_num_steps = gr.Slider( minimum=20, maximum=100, 
value=50, step=5, label="Inference Steps" ) ttm_seed = gr.Number( value=0, label="Seed", precision=0 ) ttm_generate_btn = gr.Button( "🎬 Generate TTM Video", variant="primary", size="lg", interactive=TTM_AVAILABLE ) with gr.Column(scale=1): gr.Markdown("### 📤 TTM Output") ttm_output_video = gr.Video( label="TTM Generated Video", height=400 ) ttm_status_text = gr.Markdown( "Upload a video in Step 1 first, then run TTM here.") # TTM Input preview with gr.Accordion("📁 TTM Input Files (from Step 1)", open=False): with gr.Row(): ttm_preview_first_frame = gr.Image( label="First Frame", height=150 ) ttm_preview_motion = gr.Video( label="Motion Signal", height=150 ) ttm_preview_mask = gr.Video( label="Mask", height=150 ) # Helper function to update states and preview def process_and_update_states(video_path, camera_movement, generate_ttm_flag, progress=gr.Progress()): result = process_video(video_path, camera_movement, generate_ttm_flag, progress) output_vid, motion_sig, mask_vid, first_frame, status = result # Return all outputs including state updates and previews return ( output_vid, # output_video motion_sig, # motion_signal_output mask_vid, # mask_output first_frame, # first_frame_output status, # status_text first_frame, # first_frame_state motion_sig, # motion_signal_state mask_vid, # mask_state first_frame, # ttm_preview_first_frame motion_sig, # ttm_preview_motion mask_vid, # ttm_preview_mask ) # Event handlers generate_btn.click( fn=process_and_update_states, inputs=[video_input, camera_movement, generate_ttm], outputs=[ output_video, motion_signal_output, mask_output, first_frame_output, status_text, first_frame_state, motion_signal_state, mask_state, ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask ] ) # TTM generation event ttm_generate_btn.click( fn=run_ttm_generation, inputs=[ first_frame_state, motion_signal_state, mask_state, ttm_prompt, ttm_negative_prompt, ttm_model_choice, ttm_tweak_index, ttm_tstrong_index, ttm_num_frames, ttm_num_steps, ttm_guidance_scale, ttm_seed ], outputs=[ttm_output_video, ttm_status_text] ) # Examples gr.Markdown("### 📁 Examples") if os.path.exists("./examples"): example_videos = [f for f in os.listdir( "./examples") if f.endswith(".mp4")][:4] if example_videos: gr.Examples( examples=[[f"./examples/{v}", "move_forward", True] for v in example_videos], inputs=[video_input, camera_movement, generate_ttm], outputs=[ output_video, motion_signal_output, mask_output, first_frame_output, status_text, first_frame_state, motion_signal_state, mask_state, ttm_preview_first_frame, ttm_preview_motion, ttm_preview_mask ], fn=process_and_update_states, cache_examples=False ) # Launch logger.info("✅ Gradio interface created successfully!") logger.info("=" * 50) logger.info("Application ready to launch") logger.info("=" * 50) sys.stdout.flush() if __name__ == "__main__": logger.info("Starting Gradio server...") sys.stdout.flush() demo.launch(share=False)
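# ---------------------------------------------------------------------------
# Illustrative sketch (not wired into the app): the dual-clock compositing step
# that both TTM denoising loops above apply after each scheduler step. Tensor
# shapes and the toy noise schedule below are stand-in assumptions, not the
# real pipeline values.
def dual_clock_composite(latents: torch.Tensor,
                         noisy_ref_latents: torch.Tensor,
                         motion_mask: torch.Tensor) -> torch.Tensor:
    """Keep the freely denoised sample outside the motion mask and re-inject the
    (noised) reference motion-signal latents inside it."""
    background_mask = 1.0 - motion_mask
    return latents * background_mask + noisy_ref_latents * motion_mask


def _dual_clock_toy_example() -> torch.Tensor:
    """Tiny self-contained demo with random tensors (shapes are illustrative only)."""
    b, f, c, h, w = 1, 13, 16, 24, 24           # (batch, latent frames, channels, H', W')
    latents = torch.randn(b, f, c, h, w)         # current denoising state
    ref_latents = torch.randn(b, f, c, h, w)     # encoded motion_signal.mp4 (stand-in)
    motion_mask = (torch.rand(b, f, 1, h, w) > 0.5).float()
    noisy_ref = ref_latents + 0.1 * torch.randn_like(ref_latents)  # stand-in for scheduler.add_noise
    return dual_clock_composite(latents, noisy_ref, motion_mask)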