import importlib.util
from diffusers import AutoencoderKL
from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor,
CLIPTextModel, CLIPTokenizer,
CLIPVisionModelWithProjection, LlamaModel,
LlamaTokenizerFast, LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration, PixtralProcessor,
Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer,
T5TokenizerFast)
try:
    from transformers import (Qwen2_5_VLConfig,
                              Qwen2_5_VLForConditionalGeneration,
                              Qwen2Tokenizer, Qwen2VLProcessor)
except ImportError:
    Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
    Qwen2VLProcessor, Qwen2_5_VLConfig = None, None
    print("Your transformers version is too old to load "
          "Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish "
          "to use QwenImage, please upgrade transformers to the latest version.")
from .cogvideox_transformer3d import CogVideoXTransformer3DModel
from .cogvideox_vae import AutoencoderKLCogVideoX
from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder
from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel
from .flux2_image_processor import Flux2ImageProcessor
from .flux2_transformer2d import Flux2Transformer2DModel
from .flux2_transformer2d_control import Flux2ControlTransformer2DModel
from .flux2_vae import AutoencoderKLFlux2
from .flux_transformer2d import FluxTransformer2DModel
from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel
from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo
from .qwenimage_transformer2d import QwenImageTransformer2DModel
from .qwenimage_vae import AutoencoderKLQwenImage
from .wan_audio_encoder import WanAudioEncoder
from .wan_image_encoder import CLIPModel
from .wan_text_encoder import WanT5EncoderModel
from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
WanSelfAttention, WanTransformer3DModel)
from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate
from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
from .wan_transformer3d_vace import VaceWanTransformer3DModel
from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
from .z_image_transformer2d import ZImageTransformer2DModel
from .z_image_transformer2d_control import ZImageControlTransformer2DModel
# pai_fuser (installed as "paifuser") is an internally developed acceleration
# package, available on PAI (Alibaba Cloud Platform for AI).
if importlib.util.find_spec("paifuser") is not None:
    # --------------------------------------------------------------- #
    #   simple_wrapper re-exposes a patched function through a plain
    #   Python frame, avoiding conflicts between Cython-compiled
    #   callables and torch.compile.
    # --------------------------------------------------------------- #
    def simple_wrapper(func):
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner
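    # Illustrative only: torch.compile cannot trace Cython frames directly,
    # but it can trace the pure-Python `inner` above, which then invokes the
    # Cython callable as an opaque step. Hypothetical sketch:
    #
    #     fast_forward = torch.compile(simple_wrapper(cython_forward))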
    # --------------------------------------------------------------- #
    #   VAE Parallel Kernel
    # --------------------------------------------------------------- #
    from ..dist import parallel_magvit_vae

    # Patch the WAN VAE decoders with the parallel MagViT VAE kernel.
    AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
    AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))
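    # The patch is transparent to callers, e.g.:
    #
    #     frames = vae.decode(latents)   # same API, parallel kernel underneath
    #
    # The positional arguments (0.4 and 8/16) are pai_fuser tuning parameters
    # taken as-is from the original call sites; their semantics are not
    # documented here.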
    # --------------------------------------------------------------- #
    #   Sparse Attention
    # --------------------------------------------------------------- #
    import torch
    from paifuser.ops import wan_sparse_attention_wrapper

    WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
    print("Import Sparse Attention")

    # Re-wrap the transformer forward so torch.compile traces a Python frame.
    WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)
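    # Sparse attention restricts each query to a subset of key/value positions,
    # cutting the quadratic attention cost on long video token sequences; the
    # wrapper swaps this in behind the original WanSelfAttention.forward
    # signature, so no call sites change.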
    # --------------------------------------------------------------- #
    #   CFG Skip Turbo
    # --------------------------------------------------------------- #
    import os

    if importlib.util.find_spec("paifuser.accelerator") is not None:
        from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
                                          enable_cfg_skip, share_cfg_skip)
    else:
        from paifuser import (cfg_skip_turbo, disable_cfg_skip,
                              enable_cfg_skip, share_cfg_skip)

    WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
    WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
    WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)
    QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip)
    QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip)
    print("Import CFG Skip Turbo")
    # --------------------------------------------------------------- #
    #   RMS Norm Kernel
    # --------------------------------------------------------------- #
    from paifuser.ops import rms_norm_forward

    WanRMSNorm.forward = rms_norm_forward
    print("Import PAI RMS Fuse")
    # --------------------------------------------------------------- #
    #   Fast Rope Kernel
    # --------------------------------------------------------------- #
    import types

    import torch
    from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
                              rope_apply_real_qk)

    from . import wan_transformer3d

    def deepcopy_function(f):
        # Snapshot a function object so the original pure-Python implementation
        # survives the monkey-patching of the module attribute below.
        return types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
                                  argdefs=f.__defaults__, closure=f.__closure__)
    # Keep a reference to the original pure-Python RoPE so it can still be
    # called when gradients are required.
    local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)

    if ENABLE_KERNEL:
        def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
            # The fused kernel is inference-only here; fall back to the
            # pure-Python implementation whenever autograd is active.
            if torch.is_grad_enabled():
                return local_rope_apply_qk(q, k, grid_sizes, freqs)
            else:
                return fast_rope_apply_qk(q, k, grid_sizes, freqs)
    else:
        def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
            return rope_apply_real_qk(q, k, grid_sizes, freqs)

    wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
    rope_apply_qk = adaptive_fast_rope_apply_qk
    print("Import PAI Fast rope")