import importlib.util

from diffusers import AutoencoderKL
from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor,
                          CLIPTextModel, CLIPTokenizer,
                          CLIPVisionModelWithProjection, LlamaModel,
                          LlamaTokenizerFast, LlavaForConditionalGeneration,
                          Mistral3ForConditionalGeneration, PixtralProcessor,
                          Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer,
                          T5TokenizerFast)

try:
    from transformers import (Qwen2_5_VLConfig,
                              Qwen2_5_VLForConditionalGeneration,
                              Qwen2Tokenizer, Qwen2VLProcessor)
except ImportError:
    Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
    Qwen2VLProcessor, Qwen2_5_VLConfig = None, None
    print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade transformers to the latest version.")

from .cogvideox_transformer3d import CogVideoXTransformer3DModel
from .cogvideox_vae import AutoencoderKLCogVideoX
from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder
from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel
from .flux2_image_processor import Flux2ImageProcessor
from .flux2_transformer2d import Flux2Transformer2DModel
from .flux2_transformer2d_control import Flux2ControlTransformer2DModel
from .flux2_vae import AutoencoderKLFlux2
from .flux_transformer2d import FluxTransformer2DModel
from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel
from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo
from .qwenimage_transformer2d import QwenImageTransformer2DModel
from .qwenimage_vae import AutoencoderKLQwenImage
from .wan_audio_encoder import WanAudioEncoder
from .wan_image_encoder import CLIPModel
from .wan_text_encoder import WanT5EncoderModel
from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
                                WanSelfAttention, WanTransformer3DModel)
from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate
from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
from .wan_transformer3d_vace import VaceWanTransformer3DModel
from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
from .z_image_transformer2d import ZImageTransformer2DModel
from .z_image_transformer2d_control import ZImageControlTransformer2DModel

# paifuser is an internally developed acceleration package that can be used on PAI.
if importlib.util.find_spec("paifuser") is not None:
    # --------------------------------------------------------------- #
    #   simple_wrapper works around conflicts between Cython
    #   and torch.compile.
    # --------------------------------------------------------------- #
    def simple_wrapper(func):
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner
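
    # Usage sketch (hypothetical names): wrapping a Cython-compiled callable
    # in a plain Python closure gives torch.compile an ordinary Python
    # function to trace:
    #     patched = simple_wrapper(cython_compiled_fn)
    #     patched(x)  # same result as cython_compiled_fn(x)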

    # --------------------------------------------------------------- #
    #   VAE Parallel Kernel
    # --------------------------------------------------------------- #
    from ..dist import parallel_magvit_vae
    AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
    AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))
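
    # Note: decode is patched in place on the classes themselves, so any
    # pipeline built on these VAEs picks up the parallel kernel automatically.
    # The parallel_magvit_vae(...) arguments follow the internal paifuser API.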

    # --------------------------------------------------------------- #
    #   Sparse Attention
    # --------------------------------------------------------------- #
    import torch
    from paifuser.ops import wan_sparse_attention_wrapper
    
    WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
    print("Import Sparse Attention")

    # The top-level forward is also re-exposed through simple_wrapper,
    # presumably so torch.compile can still trace it after the patch above.
    WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)

    # --------------------------------------------------------------- #
    #   CFG Skip Turbo
    # --------------------------------------------------------------- #
    # Some paifuser builds expose the CFG-skip helpers under
    # paifuser.accelerator; otherwise fall back to the top-level module.
    if importlib.util.find_spec("paifuser.accelerator") is not None:
        from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
                                          enable_cfg_skip, share_cfg_skip)
    else:
        from paifuser import (cfg_skip_turbo, disable_cfg_skip,
                              enable_cfg_skip, share_cfg_skip)

    WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
    WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
    WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)

    QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip)
    QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip)
    print("Import CFG Skip Turbo")

    # --------------------------------------------------------------- #
    #   RMS Norm Kernel
    # --------------------------------------------------------------- #
    from paifuser.ops import rms_norm_forward

    # Unlike the wrappers above, this swaps WanRMSNorm.forward wholesale for
    # the fused kernel rather than wrapping the original implementation.
    WanRMSNorm.forward = rms_norm_forward
    print("Imported PAI RMS Fuse")

    # --------------------------------------------------------------- #
    #   Fast Rope Kernel
    # --------------------------------------------------------------- #
    import types
    from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
                              rope_apply_real_qk)

    from . import wan_transformer3d

    def deepcopy_function(f):
        # Make an independent copy of the function object (sharing code and
        # globals) so the original rope implementation can still be called
        # after wan_transformer3d.rope_apply_qk is rebound below.
        return types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
                                  argdefs=f.__defaults__, closure=f.__closure__)

    local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)

    if ENABLE_KERNEL:
        def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
            # Use the fused kernel only when gradients are disabled
            # (inference); otherwise fall back to the original implementation.
            if torch.is_grad_enabled():
                return local_rope_apply_qk(q, k, grid_sizes, freqs)
            else:
                return fast_rope_apply_qk(q, k, grid_sizes, freqs)
    else:
        def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
            return rope_apply_real_qk(q, k, grid_sizes, freqs)
            
    # Rebind both the module attribute and the local name so every call site
    # picks up the adaptive implementation.
    wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
    rope_apply_qk = adaptive_fast_rope_apply_qk
    print("Imported PAI Fast RoPE")