# TRELLIS2 / app.py
"""
TRELLIS.2 Text-to-3D Generator
🎨 Comic Classic Theme
"""
import os
import shutil
import torch
import numpy as np
from PIL import Image
import tempfile
import uuid
from typing import Tuple
from datetime import datetime
import rerun as rr
try:
import rerun.blueprint as rrb
except ImportError:
rrb = None
from gradio_rerun import Rerun
import gradio as gr
from gradio_client import Client, handle_file
import spaces
from diffusers import ZImagePipeline
from trellis2.pipelines import Trellis2ImageTo3DPipeline
import o_voxel
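# Environment configuration, applied before any model weights are loaded:
# - OPENCV_IO_ENABLE_OPENEXR lets OpenCV read/write EXR images.
# - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True reduces CUDA memory fragmentation.
# - ATTN_BACKEND and the FLEX_GEMM_* variables configure the attention backend and the
#   flex_gemm autotune cache used by the TRELLIS.2 stack (cache stored next to this script).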
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
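# Both models are loaded eagerly at startup. Loading is wrapped in try/except so the
# Gradio UI can still launch if a download or initialization fails; the generation
# handlers below raise a gr.Error when the corresponding pipeline is None.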
print("Loading Z-Image-Turbo...")
try:
z_pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16, low_cpu_mem_usage=False)
device = "cuda" if torch.cuda.is_available() else "cpu"
z_pipe.to(device)
except Exception as e:
print(f"Failed to load Z-Image-Turbo: {e}")
z_pipe = None
print("Loading TRELLIS.2...")
try:
trellis_pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
trellis_pipeline.rembg_model = None
trellis_pipeline.low_vram = False
trellis_pipeline.cuda()
except Exception as e:
print(f"Failed to load TRELLIS.2: {e}")
trellis_pipeline = None
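# Background removal is delegated to the public briaai/BRIA-RMBG-2.0 Space via
# gradio_client; the pipeline's built-in rembg model is disabled above, presumably
# to avoid loading a second background-removal model locally.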
rmbg_client = Client("briaai/BRIA-RMBG-2.0")
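# Per-session scratch space: each Gradio session gets its own folder under TMP_DIR,
# created on page load and deleted when the session ends.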
def start_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
def end_session(req: gr.Request):
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
if os.path.exists(user_dir):
shutil.rmtree(user_dir)
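# remove_background round-trips the image through a temporary PNG and the remote
# RMBG endpoint, which is expected to return a cut-out with an alpha channel.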
def remove_background(input: Image.Image) -> Image.Image:
with tempfile.NamedTemporaryFile(suffix='.png') as f:
input = input.convert('RGB')
input.save(f.name)
output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
output = Image.open(output)
return output
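# preprocess_image: downscale so the longest side is at most 1024 px, remove the
# background unless the input already has a non-trivial alpha channel, crop to the
# alpha bounding box, and composite the RGB channels over black.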
def preprocess_image(input: Image.Image) -> Image.Image:
if input is None:
return None
has_alpha = False
if input.mode == 'RGBA':
alpha = np.array(input)[:, :, 3]
if not np.all(alpha == 255):
has_alpha = True
max_size = max(input.size)
scale = min(1, 1024 / max_size)
if scale < 1:
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
if has_alpha:
output = input
else:
output = remove_background(input)
output_np = np.array(output)
alpha = output_np[:, :, 3]
bbox = np.argwhere(alpha > 0.8 * 255)
if bbox.size == 0:
return output
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
size = int(size * 1)
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
output = output.crop(bbox)
output = np.array(output).astype(np.float32) / 255
output = output[:, :, :3] * output[:, :, 3:4]
output = Image.fromarray((output * 255).astype(np.uint8))
return output
def get_seed(randomize_seed: bool, seed: int) -> int:
return np.random.randint(0, MAX_SEED) if randomize_seed else seed
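# Text-to-image stage. @spaces.GPU requests a ZeroGPU slot for the duration of the call.
# Note that this stage uses a fixed seed (42); the seed slider only affects the 3D stage.
# Few steps with guidance_scale=0.0 is the usual setting for turbo-distilled models.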
@spaces.GPU
def generate_txt2img(prompt, progress=gr.Progress(track_tqdm=True)):
if z_pipe is None:
raise gr.Error("Z-Image-Turbo model failed to load.")
if not prompt.strip():
raise gr.Error("Please enter a prompt.")
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device).manual_seed(42)
progress(0.1, desc="Generating Image...")
try:
result = z_pipe(
prompt=prompt,
negative_prompt=None,
height=1024,
width=1024,
num_inference_steps=9,
guidance_scale=0.0,
generator=generator,
)
return result.images[0]
except Exception as e:
raise gr.Error(f"Generation failed: {str(e)}")
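# Image-to-3D stage (up to 120 s of GPU time). TRELLIS.2 samples in three stages
# (sparse structure, shape SLat, texture/material SLat), each with its own guidance and
# step controls exposed in the "Advanced Sampler" accordion. The resulting mesh is
# simplified, baked to a textured GLB via o_voxel (falling back to a non-remeshed export
# if remeshing fails), and logged to a Rerun recording for the embedded viewer.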
@spaces.GPU(duration=120)
def generate_3d(
image: Image.Image, seed: int, resolution: str,
decimation_target: int, texture_size: int,
ss_guidance_strength: float, ss_guidance_rescale: float,
ss_sampling_steps: int, ss_rescale_t: float,
shape_guidance: float, shape_rescale: float,
shape_steps: int, shape_rescale_t: float,
tex_guidance: float, tex_rescale: float,
tex_steps: int, tex_rescale_t: float,
req: gr.Request, progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
if image is None:
raise gr.Error("Please provide an input image.")
if trellis_pipeline is None:
raise gr.Error("TRELLIS model is not loaded.")
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(user_dir, exist_ok=True)
progress(0.1, desc="Generating 3D...")
try:
outputs, latents = trellis_pipeline.run(
image, seed=seed, preprocess_image=False,
sparse_structure_sampler_params={"steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength, "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t},
shape_slat_sampler_params={"steps": shape_steps, "guidance_strength": shape_guidance, "guidance_rescale": shape_rescale, "rescale_t": shape_rescale_t},
tex_slat_sampler_params={"steps": tex_steps, "guidance_strength": tex_guidance, "guidance_rescale": tex_rescale, "rescale_t": tex_rescale_t},
pipeline_type={"512": "512", "1024": "1024_cascade", "1536": "1536_cascade"}[resolution],
return_latent=True,
)
progress(0.7, desc="Processing Mesh...")
mesh = outputs[0]
mesh.simplify(1000000)
progress(0.9, desc="Exporting GLB...")
grid_size = latents[2]
try:
glb = o_voxel.postprocess.to_glb(
vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
coords=mesh.coords, attr_layout=trellis_pipeline.pbr_attr_layout,
grid_size=grid_size, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
decimation_target=decimation_target, texture_size=texture_size,
remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True,
)
except RuntimeError:
glb = o_voxel.postprocess.to_glb(
vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs,
coords=mesh.coords, attr_layout=trellis_pipeline.pbr_attr_layout,
grid_size=grid_size, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
decimation_target=decimation_target, texture_size=texture_size,
remesh=False, remesh_band=1, remesh_project=0, use_tqdm=True,
)
timestamp = datetime.now().strftime("%Y-%m-%dT%H%M%S")
glb_path = os.path.join(user_dir, f'output_{timestamp}.glb')
glb.export(glb_path, extension_webp=False)
progress(0.95, desc="Creating Viewer...")
run_id = str(uuid.uuid4())
# Create a Rerun recording using whichever constructor this SDK version provides,
# falling back to the rr module itself (which exposes the same logging API).
if hasattr(rr, "new_recording"):
    rec = rr.new_recording(application_id="TRELLIS-3D-Viewer", recording_id=run_id)
elif hasattr(rr, "RecordingStream"):
    rec = rr.RecordingStream(application_id="TRELLIS-3D-Viewer", recording_id=run_id)
else:
    rec = rr
rec.log("world", rr.Clear(recursive=True), static=True)
rec.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, static=True)
rec.log("world/model", rr.Asset3D(path=glb_path), static=True)
if rrb is not None:
try:
blueprint = rrb.Blueprint(rrb.Spatial3DView(origin="/world", name="3D View"), collapse_panels=True)
rec.send_blueprint(blueprint)
except Exception:
pass
rrd_path = os.path.join(user_dir, f'output_{timestamp}.rrd')
rec.save(rrd_path)
torch.cuda.empty_cache()
return rrd_path, glb_path
except Exception as e:
torch.cuda.empty_cache()
raise gr.Error(f"Generation failed: {str(e)}")
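# For reference, a minimal headless sketch of the same flow (commented out, not executed;
# it assumes the pipelines above loaded, that omitted sampler params fall back to sensible
# defaults, and uses a hypothetical input path):
#
#   img = preprocess_image(Image.open("my_object.png"))
#   outputs, latents = trellis_pipeline.run(
#       img, seed=0, preprocess_image=False,
#       pipeline_type="1024_cascade", return_latent=True,
#   )
#   outputs[0].simplify(1000000)  # same face budget as generate_3d above

# Comic Classic theme. The stylesheet is injected into the page via
# gr.HTML(f"<style>{css}</style>") inside the Blocks context below.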
css = """
@import url('https://fonts.googleapis.com/css2?family=Bangers&family=Comic+Neue:wght@400;700&display=swap');
.gradio-container {
background-color: #FEF9C3 !important;
background-image: radial-gradient(#1F2937 1px, transparent 1px) !important;
background-size: 20px 20px !important;
min-height: 100vh !important;
font-family: 'Comic Neue', cursive, sans-serif !important;
}
.huggingface-space-header, #space-header, .space-header,
[class*="space-header"], .svelte-1ed2p3z, .space-header-badge,
.header-badge, [data-testid="space-header"], .svelte-kqij2n,
.svelte-1ax1toq, .embed-container > div:first-child {
display: none !important;
visibility: hidden !important;
height: 0 !important;
width: 0 !important;
overflow: hidden !important;
opacity: 0 !important;
pointer-events: none !important;
}
footer, .footer, .gradio-container footer, .built-with,
[class*="footer"], .gradio-footer, .main-footer,
div[class*="footer"], .show-api, .built-with-gradio,
a[href*="gradio.app"], a[href*="huggingface.co/spaces"] {
display: none !important;
visibility: hidden !important;
height: 0 !important;
padding: 0 !important;
margin: 0 !important;
}
#col-container { max-width: 960px; margin: 0 auto; }
.header-text h1 {
font-family: 'Bangers', cursive !important;
color: #1F2937 !important;
font-size: 3.5rem !important;
font-weight: 400 !important;
text-align: center !important;
margin-bottom: 0.5rem !important;
text-shadow: 4px 4px 0px #FACC15, 6px 6px 0px #1F2937 !important;
letter-spacing: 3px !important;
-webkit-text-stroke: 2px #1F2937 !important;
}
.subtitle {
text-align: center !important;
font-family: 'Comic Neue', cursive !important;
font-size: 1.2rem !important;
color: #1F2937 !important;
margin-bottom: 1.5rem !important;
font-weight: 700 !important;
}
.gr-panel, .gr-box, .gr-form, .block, .gr-group {
background: #FFFFFF !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
box-shadow: 6px 6px 0px #1F2937 !important;
transition: all 0.2s ease !important;
}
.gr-panel:hover, .block:hover {
transform: translate(-2px, -2px) !important;
box-shadow: 8px 8px 0px #1F2937 !important;
}
textarea, input[type="text"], input[type="number"] {
background: #FFFFFF !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
color: #1F2937 !important;
font-family: 'Comic Neue', cursive !important;
font-size: 1rem !important;
font-weight: 700 !important;
transition: all 0.2s ease !important;
}
textarea:focus, input[type="text"]:focus, input[type="number"]:focus {
border-color: #3B82F6 !important;
box-shadow: 4px 4px 0px #3B82F6 !important;
outline: none !important;
}
.gr-button-primary, button.primary, .gr-button.primary {
background: #3B82F6 !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
color: #FFFFFF !important;
font-family: 'Bangers', cursive !important;
font-weight: 400 !important;
font-size: 1.3rem !important;
letter-spacing: 2px !important;
padding: 14px 28px !important;
box-shadow: 5px 5px 0px #1F2937 !important;
transition: all 0.1s ease !important;
text-shadow: 1px 1px 0px #1F2937 !important;
}
.gr-button-primary:hover, button.primary:hover, .gr-button.primary:hover {
background: #2563EB !important;
transform: translate(-2px, -2px) !important;
box-shadow: 7px 7px 0px #1F2937 !important;
}
.gr-button-primary:active, button.primary:active, .gr-button.primary:active {
transform: translate(3px, 3px) !important;
box-shadow: 2px 2px 0px #1F2937 !important;
}
.gr-button-secondary, button.secondary {
background: #EF4444 !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
color: #FFFFFF !important;
font-family: 'Bangers', cursive !important;
font-weight: 400 !important;
font-size: 1.1rem !important;
letter-spacing: 1px !important;
box-shadow: 4px 4px 0px #1F2937 !important;
transition: all 0.1s ease !important;
text-shadow: 1px 1px 0px #1F2937 !important;
}
.gr-button-secondary:hover, button.secondary:hover {
background: #DC2626 !important;
transform: translate(-2px, -2px) !important;
box-shadow: 6px 6px 0px #1F2937 !important;
}
label, .gr-input-label, .gr-block-label {
color: #1F2937 !important;
font-family: 'Comic Neue', cursive !important;
font-weight: 700 !important;
font-size: 1rem !important;
}
.gr-file-upload {
border: 3px dashed #1F2937 !important;
border-radius: 8px !important;
background: #FEF9C3 !important;
}
.gr-file-upload:hover {
border-color: #3B82F6 !important;
background: #EFF6FF !important;
}
::-webkit-scrollbar { width: 12px; height: 12px; }
::-webkit-scrollbar-track { background: #FEF9C3; border: 2px solid #1F2937; }
::-webkit-scrollbar-thumb { background: #3B82F6; border: 2px solid #1F2937; border-radius: 0px; }
::-webkit-scrollbar-thumb:hover { background: #EF4444; }
::selection { background: #FACC15; color: #1F2937; }
a { color: #3B82F6 !important; text-decoration: none !important; font-weight: 700 !important; }
a:hover { color: #EF4444 !important; }
@media (max-width: 768px) {
.header-text h1 {
font-size: 2.2rem !important;
text-shadow: 3px 3px 0px #FACC15, 4px 4px 0px #1F2937 !important;
}
.gr-button-primary, button.primary { padding: 12px 20px !important; font-size: 1.1rem !important; }
.gr-panel, .block { box-shadow: 4px 4px 0px #1F2937 !important; }
}
@media (prefers-color-scheme: dark) {
.gradio-container { background-color: #FEF9C3 !important; }
}
"""
EXAMPLES_IMAGE = [f"example-images/A ({i}).webp" for i in range(1, 72)]
EXAMPLES_TEXT = [
"A Cat 3D model", "A realistic Cat 3D model", "A cartoon Cat 3D model",
"A low poly Cat 3D", "A cyberpunk Cat 3D", "A robotic Cat 3D",
"A Plane 3D model", "A fighter jet Plane 3D", "A vintage Plane 3D",
"A Car 3D model", "A sports Car 3D", "A cyberpunk Car 3D",
"A Shoe 3D model", "A sneaker Shoe 3D", "A boot Shoe 3D",
"A Chair 3D model", "A Table 3D model", "A Robot 3D model",
"A House 3D model", "A Spaceship 3D model", "A Motorcycle 3D model",
]
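# UI assembly. The left column holds the text/image inputs and sampler settings; the
# right column holds the embedded Rerun viewer, the GLB download button, and the
# Examples galleries (71 bundled webp images plus the text prompts above). Generation
# is a two-step flow: text -> image (optional), then image -> 3D.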
if __name__ == "__main__":
os.makedirs(TMP_DIR, exist_ok=True)
with gr.Blocks(title="TRELLIS.2 Text-to-3D", delete_cache=(300, 300)) as demo:
gr.HTML(f"<style>{css}</style>")
gr.HTML("""
<div style="text-align: center; margin: 20px 0 10px 0;">
<a href="https://www.humangen.ai" target="_blank" style="text-decoration: none;">
<img src="https://img.shields.io/static/v1?label=🏠 HOME&message=HUMANGEN.AI&color=0000ff&labelColor=ffcc00&style=for-the-badge" alt="HOME">
</a>
</div>
""")
gr.Markdown("# 🎮 TRELLIS.2 TEXT-TO-3D 🎮", elem_classes="header-text")
gr.Markdown('<p class="subtitle">✨ Generate 3D models from text or images! 🚀</p>')
with gr.Row():
with gr.Column(scale=1, min_width=360):
with gr.Tabs():
with gr.Tab("📝 Text-to-3D"):
txt_prompt = gr.Textbox(label="💬 Prompt", placeholder="e.g. A Cat 3D model", lines=2)
btn_gen_img = gr.Button("1️⃣ Generate Image", variant="primary")
with gr.Tab("🖼️ Image-to-3D"):
gr.Markdown("Upload an image directly.")
image_prompt = gr.Image(label="📷 Input Image", format="png", image_mode="RGBA", type="pil", height=350)
with gr.Accordion(label="⚙️ 3D Settings", open=False):
resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
randomize_seed = gr.Checkbox(label="🎲 Randomize Seed", value=True)
decimation_target = gr.Slider(50000, 500000, label="Target Faces", value=150000, step=10000)
texture_size = gr.Slider(512, 4096, label="Texture Size", value=1024, step=512)
btn_gen_3d = gr.Button("2️⃣ Generate 3D", variant="primary")
with gr.Accordion(label="🔧 Advanced Sampler", open=False):
gr.Markdown("**Stage 1: Sparse Structure**")
ss_guidance_strength = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
ss_guidance_rescale = gr.Slider(0.0, 1.0, value=0.7, label="Rescale")
ss_sampling_steps = gr.Slider(1, 50, value=12, label="Steps")
ss_rescale_t = gr.Slider(1.0, 6.0, value=5.0, label="Rescale T")
gr.Markdown("**Stage 2: Shape**")
shape_guidance = gr.Slider(1.0, 10.0, value=7.5, label="Guidance")
shape_rescale = gr.Slider(0.0, 1.0, value=0.5, label="Rescale")
shape_steps = gr.Slider(1, 50, value=12, label="Steps")
shape_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")
gr.Markdown("**Stage 3: Material**")
tex_guidance = gr.Slider(1.0, 10.0, value=1.0, label="Guidance")
tex_rescale = gr.Slider(0.0, 1.0, value=0.0, label="Rescale")
tex_steps = gr.Slider(1, 50, value=12, label="Steps")
tex_rescale_t = gr.Slider(1.0, 6.0, value=3.0, label="Rescale T")
with gr.Column(scale=2):
gr.Markdown("### 🎯 3D Output")
rerun_output = Rerun(label="3D Viewer", height=600)
download_btn = gr.DownloadButton(label="3️⃣ Download GLB", variant="primary")
gr.Examples(examples=[[img] for img in EXAMPLES_IMAGE], inputs=[image_prompt], label="🖼️ Image Examples")
gr.Examples(examples=[[txt] for txt in EXAMPLES_TEXT], inputs=[txt_prompt], label="📝 Text Examples")
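# Event wiring: session lifecycle, then the two generation chains. The text button
# writes the generated image into the image component and then applies the same
# preprocessing used for direct uploads; the 3D button resolves the seed before
# calling generate_3d.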
demo.load(start_session)
demo.unload(end_session)
btn_gen_img.click(generate_txt2img, inputs=[txt_prompt], outputs=[image_prompt]).then(
preprocess_image, inputs=[image_prompt], outputs=[image_prompt]
)
image_prompt.upload(preprocess_image, inputs=[image_prompt], outputs=[image_prompt])
btn_gen_3d.click(get_seed, inputs=[randomize_seed, seed], outputs=[seed]).then(
generate_3d,
inputs=[
image_prompt, seed, resolution, decimation_target, texture_size,
ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
shape_guidance, shape_rescale, shape_steps, shape_rescale_t,
tex_guidance, tex_rescale, tex_steps, tex_rescale_t,
],
outputs=[rerun_output, download_btn],
)
demo.launch()