Instructions to use Deci/DeciDiffusion-v2-0 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use Deci/DeciDiffusion-v2-0 with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("Deci/DeciDiffusion-v2-0", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- Draw Things
- DiffusionBee
| import itertools | |
| from functools import partial | |
| from typing import Any, Dict, Tuple, Callable | |
| from typing import Union, Optional, List | |
| import numpy as np | |
| import torch | |
| from diffusers import DPMSolverMultistepScheduler | |
| from diffusers import StableDiffusionPipeline, AutoencoderKL | |
| from diffusers import Transformer2DModel, ModelMixin, ConfigMixin, SchedulerMixin | |
| from diffusers import UNet2DConditionModel | |
| from diffusers.configuration_utils import register_to_config | |
| from diffusers.models.attention import BasicTransformerBlock | |
| from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D | |
| from diffusers.models.transformer_2d import Transformer2DModelOutput | |
| from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput | |
| from diffusers.schedulers import KarrasDiffusionSchedulers | |
| from diffusers.utils import replace_example_docstring | |
| from torch import nn | |
| from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor | |
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself whose per-sample std matches `noise_pred_text`.

    Based on Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    :param noise_cfg: the guided noise prediction to rescale
    :param noise_pred_text: the text-conditioned noise prediction whose std is the target
    :param guidance_rescale: mixing weight; 0.0 keeps `noise_cfg`, 1.0 uses only the rescaled tensor
    :return: the mixed noise prediction
    """

    def _std_over_non_batch_dims(tensor):
        # Reduce over every dimension except the leading one, keeping dims for broadcasting.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    text_std = _std_over_non_batch_dims(noise_pred_text)
    cfg_std = _std_over_non_batch_dims(noise_cfg)

    # Std-matched version of the guided prediction (fixes overexposure).
    rescaled = noise_cfg * (text_std / cfg_std)

    # Mixing with the un-rescaled prediction avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def custom_sort_order(obj):
    """
    Sort key that fixes the execution order of sub-modules in forward methods.

    Resnet blocks rank 0 and transformer blocks rank 1; the lookup is on the exact class of
    `obj`, and any other class yields `None`.
    """
    rank_by_class = {
        ResnetBlock2D: 0,
        Transformer2DModel: 1,
        FlexibleTransformer2DModel: 1,
    }
    return rank_by_class.get(obj.__class__)
def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    """
    Shorten a descending timestep array to `n` entries: keep the first `i` entries as they are
    and replace the tail with evenly spaced multiples of a step `k`.

    :param timestep_spacing: the timestep_spacing array we want to squeeze
    :param n: the size of the squeezed array
    :param i: the index we start squeezing from (must satisfy 0 < i < n)
    :return: squeezed timestep_spacing (numpy array of length n)
    :raises ValueError: if `i` is not in (0, n) or `timestep_spacing` has fewer than `i` entries

    Example:
        timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60]) (len=16)
        n = 10, i = 6

    Expected:
        [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k], and if we define 665=5k => k = 133
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`. i == 0 is
    # rejected because `squeezed[i - 1]` would wrap around to the last element.
    if not 0 < i < n:
        raise ValueError(f"Expected 0 < i < n, got n={n}, i={i}")
    if len(timestep_spacing) < i:
        raise ValueError(f"timestep_spacing must have at least i={i} entries, got {len(timestep_spacing)}")

    squeezed = np.flip(np.arange(n)) + 1  # [n, n-1, ..., 2, 1]
    squeezed[:i] = timestep_spacing[:i]

    # The last kept timestep is treated as (n - i + 1) * k, so the tail becomes (n - i) * k, ..., 2k, k.
    k = squeezed[i - 1] // (n - i + 1)
    squeezed[i:] *= k
    return squeezed
# Squeezers addressable by a "<n>,<i>" key; each one is `squeeze_to_len_n_starting_from_index_i`
# with `n` and `i` pre-bound, so it only needs the timestep array.
# Tested with DPM 16-steps (reduced 16 -> 10 or 11 steps)
PREDEFINED_TIMESTEP_SQUEEZERS = {
    f"{n},{i}": partial(squeeze_to_len_n_starting_from_index_i, n, i)
    for n, i in ((10, 6), (11, 7))
}
# Architecture description read by `FlexibleUNet2DConditionModel.__init__`.
FlexibleUnetConfigurations = dict(
    # General parameters for all blocks
    sample_size=64,
    temb_dim=320 * 4,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=8,
    cross_attention_dim=768,
    # Controls modules execute order in unet's forward
    mix_block_in_forward=True,
    # Down blocks parameters (one list entry per down block)
    down_blocks_in_channels=[320, 320, 640],
    down_blocks_out_channels=[320, 640, 1280],
    down_blocks_num_attentions=[0, 1, 3],
    down_blocks_num_resnets=[2, 2, 1],
    add_downsample=[True, True, False],
    # Middle block parameters (zero resnets and zero attentions selects the identity mid block)
    add_upsample_mid_block=None,
    mid_num_resnets=0,
    mid_num_attentions=0,
    # Up block parameters (one list entry per up block)
    prev_output_channels=[1280, 1280, 640],
    up_blocks_num_attentions=[5, 3, 0],
    up_blocks_num_resnets=[2, 3, 3],
    add_upsample=[True, True, False],
)
class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler, SchedulerMixin):
    """
    `DPMSolverMultistepScheduler` with minor differences:
    * Defaults are modified to accommodate DeciDiffusion
    * It supports a squeezer to squeeze the number of inference steps to a smaller number

    //!\\ IMPORTANT: the actual number of inference steps is deduced by the squeezer, and not the pipeline!
    """

    # NOTE(review): `register_to_config` is imported by this file but not applied here, so
    # `squeeze_mode` is presumably not part of the saved scheduler config — confirm.
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "squaredcos_cap_v2",  # NOTE THIS DEFAULT VALUE
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "v_prediction",  # NOTE THIS DEFAULT VALUE
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "heun",  # NOTE THIS DEFAULT VALUE
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -7.5,  # NOTE THIS DEFAULT VALUE
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 1,
        squeeze_mode: Optional[str] = None,  # NOTE THIS ADDITION. Supports keys from `PREDEFINED_TIMESTEP_SQUEEZERS` defined above
    ):
        """
        Build the scheduler and select the timestep squeezer.

        All arguments except `squeeze_mode` are forwarded to `DPMSolverMultistepScheduler.__init__`
        (`use_karras_sigmas` is always forwarded as `False`).

        Args:
            squeeze_mode (`str`, *optional*):
                Key of `PREDEFINED_TIMESTEP_SQUEEZERS`, or `None` to disable squeezing.

        Raises:
            ValueError: if `squeeze_mode` is neither `None` nor a key of `PREDEFINED_TIMESTEP_SQUEEZERS`.
            NotImplementedError: if `use_karras_sigmas` is truthy.
        """
        # Fail loudly on an unknown key: a typo would otherwise silently disable squeezing.
        if squeeze_mode is not None and squeeze_mode not in PREDEFINED_TIMESTEP_SQUEEZERS:
            raise ValueError(
                f"Unknown squeeze_mode {squeeze_mode!r}. "
                f"Supported values: {sorted(PREDEFINED_TIMESTEP_SQUEEZERS)} or None"
            )
        self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode)

        if use_karras_sigmas:
            raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`")

        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            solver_order=solver_order,
            prediction_type=prediction_type,
            thresholding=thresholding,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
            sample_max_value=sample_max_value,
            algorithm_type=algorithm_type,
            solver_type=solver_type,
            lower_order_final=lower_order_final,
            use_karras_sigmas=False,
            lambda_min_clipped=lambda_min_clipped,
            variance_type=variance_type,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )

    def set_timesteps(self, num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        When a squeezer is configured, the timesteps produced by the parent scheduler are squeezed
        and `sigmas`, `timesteps` and `num_inference_steps` are overwritten accordingly.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        super().set_timesteps(num_inference_steps=num_inference_steps, device=device)
        if self._squeezer is not None:
            timesteps = self._squeezer(self.timesteps.cpu())

            # Recompute sigmas at the squeezed timesteps by interpolating the per-train-step
            # sigma curve sqrt((1 - alphas_cumprod) / alphas_cumprod).
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

            # Append one trailing sigma, taken from the first train step.
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

            self.sigmas = torch.from_numpy(sigmas)
            self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
            # The step count now comes from the squeezer, not from the caller's argument.
            self.num_inference_steps = len(timesteps)
class FlexibleIdentityBlock(nn.Module):
    """No-op block: accepts the same forward arguments as the mid block and returns its input."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Return `hidden_states` unchanged; every other argument is ignored."""
        passthrough = hidden_states
        return passthrough
class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    """
    `UNet2DConditionModel` whose down, mid and up blocks are rebuilt from `configurations`.

    The parent constructor runs first with `sample_size` and `cross_attention_dim` taken from the
    configuration; `down_blocks`, `mid_block` and `up_blocks` are then replaced by the flexible
    variants defined in this file.
    """

    configurations = FlexibleUnetConfigurations

    def __init__(self):
        cfg = self.configurations
        super().__init__(
            sample_size=cfg.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
            cross_attention_dim=cfg.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
        )

        # Keyword arguments common to every flexible block.
        shared_kwargs = dict(
            temb_channels=cfg.get("temb_dim"),
            resnet_eps=cfg.get("resnet_eps"),
            resnet_act_fn=cfg.get("resnet_act_fn"),
            num_attention_heads=cfg.get("num_attention_heads"),
            cross_attention_dim=cfg.get("cross_attention_dim"),
            mix_block_in_forward=cfg.get("mix_block_in_forward"),
        )

        ###############
        # Down blocks #
        ###############
        down_in_channels = cfg.get("down_blocks_in_channels")
        down_out_channels = cfg.get("down_blocks_out_channels")
        down_specs = zip(
            down_in_channels,
            down_out_channels,
            cfg.get("down_blocks_num_resnets"),
            cfg.get("down_blocks_num_attentions"),
            cfg.get("add_downsample"),
        )
        final_down_index = len(down_in_channels) - 1

        self.down_blocks = nn.ModuleList(
            [
                FlexibleCrossAttnDownBlock2D(
                    in_channels=block_in,
                    out_channels=block_out,
                    num_resnets=resnet_count,
                    num_attentions=attention_count,
                    add_downsample=with_downsample,
                    last_block=(index == final_down_index),
                    **shared_kwargs,
                )
                for index, (block_in, block_out, resnet_count, attention_count, with_downsample) in enumerate(down_specs)
            ]
        )

        ###############
        # Mid blocks  #
        ###############
        mid_resnet_count = cfg.get("mid_num_resnets")
        mid_attention_count = cfg.get("mid_num_attentions")

        # With neither resnets nor attentions requested the mid block is a pass-through.
        if mid_resnet_count == mid_attention_count and mid_attention_count == 0:
            self.mid_block = FlexibleIdentityBlock()
        else:
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
                in_channels=down_out_channels[-1],
                num_resnets=mid_resnet_count,
                num_attentions=mid_attention_count,
                add_upsample=cfg.get("add_upsample_mid_block"),
                **shared_kwargs,
            )

        ###############
        # Up blocks   #
        ###############
        # Channel lists are walked in reverse so the up path mirrors the down path.
        up_specs = zip(
            reversed(down_in_channels),
            reversed(down_out_channels),
            cfg.get("prev_output_channels"),
            cfg.get("up_blocks_num_resnets"),
            cfg.get("up_blocks_num_attentions"),
            cfg.get("add_upsample"),
        )

        self.up_blocks = nn.ModuleList(
            [
                FlexibleCrossAttnUpBlock2D(
                    in_channels=block_in,
                    out_channels=block_out,
                    prev_output_channel=previous_out,
                    num_resnets=resnet_count,
                    num_attentions=attention_count,
                    add_upsample=with_upsample,
                    **shared_kwargs,
                )
                for block_in, block_out, previous_out, resnet_count, attention_count, with_upsample in up_specs
            ]
        )
class FlexibleCrossAttnDownBlock2D(nn.Module):
    """
    Down block with independent numbers of resnet and cross-attention layers.

    Layers are created pairwise (resnet, then attention) while both counts last, and the longer
    list continues alone. When `mix_block_in_forward` is false the layers are reordered with
    `custom_sort_order`. An optional `Downsample2D` follows the layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()

        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        resnet_flags = [True] * num_resnets
        attention_flags = [True] * num_attentions
        paired_flags = itertools.zip_longest(resnet_flags, attention_flags, fillvalue=False)

        layers = []
        for position, (wants_resnet, wants_attention) in enumerate(paired_flags):
            # Only the first position sees the block's input width; later ones see `out_channels`.
            layer_in_channels = in_channels if position == 0 else out_channels

            if wants_resnet:
                resnet = ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                layers.append(resnet)

            if wants_attention:
                attention = FlexibleTransformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
                layers.append(attention)

        if not mix_block_in_forward:
            layers = sorted(layers, key=custom_sort_order)

        self.modules_list = nn.ModuleList(layers)

        self.downsamplers = None
        if add_downsample:
            downsampler = Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")
            self.downsamplers = nn.ModuleList([downsampler])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run the layers in order and collect skip states.

        Returns a pair `(hidden_states, output_states)`; `output_states` is a tuple holding the
        output of every resnet and, unless this is the last block, the downsampled output.
        Raises `ValueError` for a layer that is neither a resnet nor a transformer.
        """
        collected = []

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
                # Only resnet outputs are recorded as skip states.
                collected.append(hidden_states)
                continue

            if not isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                raise ValueError(f"Got an unexpected module in modules list! {type(module)}")

            hidden_states = module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            # NOTE(review): indentation was lost in SOURCE; this check is assumed to sit inside
            # the downsampler branch — confirm against the upstream file.
            if not self.last_block:
                collected.append(hidden_states)

        return hidden_states, tuple(collected)
class FlexibleCrossAttnUpBlock2D(nn.Module):
    """
    Up block with independent numbers of resnet and cross-attention layers.

    Every resnet consumes one skip tensor, concatenated to the hidden states along dim 1.
    When `mix_block_in_forward` is false the layers are reordered with `custom_sort_order`.
    An optional `Upsample2D` follows the layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()

        # WARNING: This parameter is filled with number of resnets and used within StableDiffusionPipeline
        self.resnets = []
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        resnet_flags = [True] * num_resnets
        attention_flags = [True] * num_attentions
        last_resnet_position = len(resnet_flags) - 1
        paired_flags = itertools.zip_longest(resnet_flags, attention_flags, fillvalue=False)

        layers = []
        for position, (wants_resnet, wants_attention) in enumerate(paired_flags):
            # The skip tensor has `in_channels` channels only at the last resnet position.
            skip_channels = in_channels if position == last_resnet_position else out_channels
            # The first position receives the previous block's output width.
            incoming_channels = prev_output_channel if position == 0 else out_channels

            if wants_resnet:
                self.resnets.append(True)
                resnet = ResnetBlock2D(
                    in_channels=incoming_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                layers.append(resnet)

            if wants_attention:
                attention = FlexibleTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
                layers.append(attention)

        if not mix_block_in_forward:
            layers = sorted(layers, key=custom_sort_order)

        self.modules_list = nn.ModuleList(layers)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run the layers in order, feeding each resnet the last remaining skip tensor, then upsample.

        Returns the resulting hidden states.
        """
        remaining_skips = res_hidden_states_tuple

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                # Skip tensors are consumed from the end of the tuple.
                skip = remaining_skips[-1]
                remaining_skips = remaining_skips[:-1]
                hidden_states = module(torch.cat([hidden_states, skip], dim=1), temb)

            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states
class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    """
    Mid block made of one leading resnet followed by a configurable mix of cross-attention and
    resnet layers, and an upsampler stage (`nn.Identity` unless `add_upsample` is truthy).

    All layers keep the channel count at `in_channels`.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        def _make_resnet():
            # Every resnet in this block shares the same configuration.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # There is always at least one resnet
        modules = [_make_resnet()]

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for add_resnet, add_cross_attention in itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False):
            # Within a pair the attention layer comes before the resnet.
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )
            if add_resnet:
                modules.append(_make_resnet())

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Apply the leading resnet, then each remaining layer once in list order, then the upsamplers.
        """
        hidden_states = self.modules_list[0](hidden_states, temb)

        # Start at index 1: the leading resnet has already run. Iterating the whole list
        # (as the previous code did) applied that resnet twice.
        # NOTE(review): a checkpoint trained with the double application would see different
        # activations — confirm none exists before relying on this block.
        for module in self.modules_list[1:]:
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states
class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    """
    2D transformer for image-like feature maps: group-norm, input projection, a stack of
    `BasicTransformerBlock`s, output projection back to `in_channels`, and a residual add.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        """
        Apply the transformer to a 4-D `hidden_states` tensor (batch, channels, height, width).

        Args:
            hidden_states: input feature map; also used as the residual.
            encoder_hidden_states, timestep, class_labels, cross_attention_kwargs,
            attention_mask, encoder_attention_mask: forwarded to every transformer block.
            return_dict: when true, return a `Transformer2DModelOutput`; otherwise a 1-tuple.

        Returns:
            `(output,)` when `return_dict` is false, else `Transformer2DModelOutput(sample=output)`.
        """
        # 1. Input
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            # Conv projection works on the image layout; flatten spatial dims afterwards.
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Linear projection works on the token layout; flatten spatial dims first.
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        # The previous condition was inverted (tuple for True, output object for False).
        # Callers in this file use `return_dict=False` followed by `[0]`, which a tuple satisfies.
        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
| class DeciDiffusionPipeline(StableDiffusionPipeline): | |
| deci_default_squeeze_mode = "10,6" | |
| deci_default_number_of_iterations = 16 | |
| deci_default_guidance_rescale = 0.8 | |
| def __init__( | |
| self, | |
| vae: AutoencoderKL, | |
| text_encoder: CLIPTextModel, | |
| tokenizer: CLIPTokenizer, | |
| unet: UNet2DConditionModel, | |
| scheduler: KarrasDiffusionSchedulers, | |
| safety_checker: StableDiffusionSafetyChecker, | |
| feature_extractor: CLIPImageProcessor, | |
| requires_safety_checker: bool = True, | |
| ): | |
| # Replace UNet with Deci`s unet | |
| del unet | |
| unet = FlexibleUNet2DConditionModel() | |
| # Replace with custom scheduler | |
| del scheduler | |
| scheduler = SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode) | |
| super().__init__( | |
| vae=vae, | |
| text_encoder=text_encoder, | |
| tokenizer=tokenizer, | |
| unet=unet, | |
| scheduler=scheduler, | |
| safety_checker=safety_checker, | |
| feature_extractor=feature_extractor, | |
| requires_safety_checker=requires_safety_checker, | |
| ) | |
| self.register_modules( | |
| vae=vae, | |
| text_encoder=text_encoder, | |
| tokenizer=tokenizer, | |
| unet=unet, | |
| scheduler=scheduler, | |
| safety_checker=safety_checker, | |
| feature_extractor=feature_extractor, | |
| ) | |
@torch.no_grad()
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 16,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.8,
):
    r"""
    The call function to the pipeline for generation.

    The whole call runs under `torch.no_grad()`, so no autograd graph is recorded for the denoising loop or the
    VAE decode.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
        height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 16):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            A higher guidance scale value encourages the model to generate images closely linked to the text
            `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide what to not include in image generation. If not defined, you need to
            pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
            to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
            not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
            `"latent"` skips the VAE decode and the safety checker and hands the final latents to the image
            post-processor.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that calls every `callback_steps` steps during inference. The function is called with the
            following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function is called. If not specified, the callback is called at
            every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            Its `"scale"` entry, when present, is also forwarded to `encode_prompt` as `lora_scale`.
        guidance_rescale (`float`, *optional*, defaults to 0.8):
            Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
            using zero terminal SNR. Only applied when classifier-free guidance is active and the value is `> 0`.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images and the
            second element is a list of `bool`s indicating whether the corresponding generated image contains
            "not-safe-for-work" (nsfw) content.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        # No usable prompt: the batch size comes from the pre-computed embeddings.
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0
    # Loop-invariant: rescaling only makes sense on top of classifier-free guidance.
    do_guidance_rescale = do_classifier_free_guidance and guidance_rescale > 0.0

    # 3. Encode input prompt
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=len(timesteps)) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                # The batch was built as [unconditional, text], so chunk(2) yields them in that order.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_guidance_rescale:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        # Latent output: nothing is decoded, so there is nothing for the safety checker to inspect.
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        # Images flagged by the safety checker are not denormalized.
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)