| | |
| | |
| | |
| | |
| | """Distilled Audio State-Space Model (DASS) model""" |
| |
|
| | import math |
| | import torch |
| | import warnings |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint as checkpoint |
| | from timm.models.layers import DropPath, trunc_normal_ |
| | from functools import partial |
| | from typing import Optional, Callable, Any, Union |
| | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss |
| | from transformers.modeling_outputs import SequenceClassifierOutput |
| |
|
| | from transformers.utils import logging |
| | from transformers.modeling_utils import PreTrainedModel |
| |
|
| | from .configuration_dass import DASSConfig |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| | |
| | _CONFIG_FOR_DOC = "DASSConfig" |
| |
|
# Optional Triton support. WITH_TRITON records whether the Triton packages
# could be imported; the cross-scan helpers further down consult it to choose
# between the Triton kernels and the pure-PyTorch implementations.
WITH_TRITON = True

try:
    import triton
    import triton.language as tl
except Exception:
    # ``except Exception`` instead of a bare ``except`` so KeyboardInterrupt and
    # SystemExit are never swallowed, while still covering installations where
    # importing Triton fails with something other than ImportError.
    WITH_TRITON = False
    warnings.warn("Triton not installed, fall back to pytorch implements.")

if WITH_TRITON:
    # NOTE(review): presumably Triton needs functools.cached_property, which is
    # missing on Python 3.7 -- the warning text below is the only evidence here.
    try:
        from functools import cached_property
    except ImportError:
        warnings.warn("if you are using py37, add this line to functools.py: "
            "cached_property = lambda func: property(lru_cache()(func))")
| |
|
| | |
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Expand one feature map into four flattened scan sequences.

    Args:
        x: ``(B, C, H, W)`` when ``in_channel_first`` else ``(B, H, W, C)``.
        in_channel_first: memory layout of ``x``.
        out_channel_first: memory layout of the result.
        scans: scan pattern.
            0 -- row-major, column-major, and the reverses of both;
            1 -- four copies of the row-major sequence;
            2 -- two row-major copies followed by two reversed copies;
            3 -- ``torch.rot90`` by 0, 1, 2 and 3 quarter turns.

    Returns:
        ``(B, 4, C, H*W)`` when ``out_channel_first`` else ``(B, H*W, 4, C)``.
    """
    if in_channel_first:
        B, C, H, W = x.shape
        L = H * W
        if scans == 0:
            y = x.new_empty((B, 4, C, L))
            y[:, 0] = x.flatten(2, 3)
            y[:, 1] = x.transpose(2, 3).flatten(2, 3)
            # Directions 2 and 3 are directions 0 and 1 traversed backwards.
            y[:, 2:4] = y[:, 0:2].flip(dims=[-1])
        elif scans == 1:
            y = x.view(B, 1, C, L).repeat(1, 4, 1, 1)
        elif scans == 2:
            forward_pair = x.view(B, 1, C, L).repeat(1, 2, 1, 1)
            y = torch.cat([forward_pair, forward_pair.flip(dims=[-1])], dim=1)
        elif scans == 3:
            y = x.new_empty((B, 4, C, L))
            for turns in range(4):
                y[:, turns] = torch.rot90(x, turns, dims=(2, 3)).flatten(2, 3)
    else:
        B, H, W, C = x.shape
        L = H * W
        if scans == 0:
            y = x.new_empty((B, L, 4, C))
            y[:, :, 0] = x.flatten(1, 2)
            y[:, :, 1] = x.transpose(1, 2).flatten(1, 2)
            y[:, :, 2:4] = y[:, :, 0:2].flip(dims=[1])
        elif scans == 1:
            y = x.view(B, L, 1, C).repeat(1, 1, 4, 1)
        elif scans == 2:
            forward_pair = x.view(B, L, 1, C).repeat(1, 1, 2, 1)
            y = torch.cat([forward_pair, forward_pair.flip(dims=[1])], dim=2)
        elif scans == 3:
            y = x.new_empty((B, L, 4, C))
            for turns in range(4):
                y[:, :, turns] = torch.rot90(x, turns, dims=(1, 2)).flatten(1, 2)

    # Convert between (B, 4, C, L) and (B, L, 4, C) when the layouts differ.
    if bool(in_channel_first) != bool(out_channel_first):
        order = (0, 3, 1, 2) if in_channel_first else (0, 2, 3, 1)
        y = y.permute(*order).contiguous()

    return y
| |
|
| |
|
def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Sum four scan sequences back into a single flattened feature map.

    This undoes the reordering applied by ``cross_scan_fwd`` for the same
    ``scans`` value and adds the four directions together.

    Args:
        y: ``(B, K, D, H, W)`` when ``out_channel_first`` else ``(B, H, W, K, D)``.
        in_channel_first: layout of the returned tensor.
        out_channel_first: layout of ``y``.
        scans: scan pattern (0-3), same meaning as in ``cross_scan_fwd``.

    Returns:
        ``(B, D, H*W)`` when ``in_channel_first`` else ``(B, H*W, D)``.
    """
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            # Un-reverse directions 2/3 and fold them onto directions 0/1.
            folded = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            # Direction 1 is column-major: bring it back to row-major order.
            row_major = folded[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
            y = folded[:, 0] + row_major
        elif scans == 1:
            y = y.sum(1)
        elif scans == 2:
            folded = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            y = folded.sum(1)
        elif scans == 3:
            merged = y[:, 0, :, :].contiguous().view(B, D, -1)
            # Quarter turns 1 and 3 swapped H and W, so they are viewed as (W, H).
            for turns, (rows, cols) in ((1, (W, H)), (2, (H, W)), (3, (W, H))):
                rotated = y.view(B, K, D, rows, cols)[:, turns, :, :, :]
                merged = merged + torch.rot90(rotated, -turns, dims=(2, 3)).flatten(2, 3)
            y = merged
    else:
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            folded = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            row_major = folded[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
            y = folded[:, :, 0] + row_major
        elif scans == 1:
            y = y.sum(2)
        elif scans == 2:
            folded = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            y = folded.sum(2)
        elif scans == 3:
            merged = y[:, :, 0, :].contiguous().view(B, -1, D)
            for turns, (rows, cols) in ((1, (W, H)), (2, (H, W)), (3, (W, H))):
                rotated = y.view(B, rows, cols, K, D)[:, :, :, turns, :]
                merged = merged + torch.rot90(rotated, -turns, dims=(1, 2)).flatten(1, 2)
            y = merged

    # (B, D, L) <-> (B, L, D) whenever the two layouts disagree.
    if bool(in_channel_first) != bool(out_channel_first):
        y = y.permute(0, 2, 1).contiguous()

    return y
| |
|
| |
|
def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Apply the four scan orderings to four separate feature maps, one each.

    Unlike ``cross_scan_fwd`` the input already carries four maps; map ``k`` is
    reordered with scan direction ``k``.

    Args:
        x: ``(B, 4, C, H, W)`` when ``in_channel_first`` else ``(B, H, W, 4, C)``.
        in_channel_first: layout of ``x``.
        out_channel_first: layout of the result.
        scans: scan pattern.
            0 -- row-major, column-major, and the reverses of both;
            1 -- all four maps kept in row-major order;
            2 -- maps 0/1 row-major, maps 2/3 reversed;
            3 -- ``torch.rot90`` by 0, 1, 2 and 3 quarter turns.

    Returns:
        ``(B, 4, C, H*W)`` when ``out_channel_first`` else ``(B, H*W, 4, C)``.
    """
    if in_channel_first:
        B, _, C, H, W = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            # x is 5-D here, so the spatial axes are 3 and 4. Flattening (2, 3)
            # would merge C with H and produce (B, 4, C*H, W).
            y = x.flatten(3, 4)
        elif scans == 2:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 3:
            y = torch.stack([
                x[:, 0, :, :, :].flatten(2, 3),
                torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3),
                torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3),
                torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3),
            ], dim=1)
    else:
        B, H, W, _, C = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = x.flatten(1, 2)
        elif scans == 2:
            # The direction axis is dim 3 in the channel-last layout, and the
            # sequence axis of each flattened map is dim 1.
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 3:
            # Stack on dim 2 so the result is (B, L, 4, C) like the other modes.
            y = torch.stack([
                x[:, :, :, 0, :].flatten(1, 2),
                torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2),
                torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2),
                torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2),
            ], dim=2)

    # Convert between (B, 4, C, L) and (B, L, 4, C) when the layouts differ.
    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y
| |
|
| |
|
def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Restore row-major order for each of four scan sequences, without summing.

    Direction ``k`` of ``y`` is mapped back with the inverse of scan ``k``; the
    four maps stay separate.

    Args:
        y: ``(B, K, D, H, W)`` when ``out_channel_first`` else ``(B, H, W, K, D)``.
        in_channel_first: layout of the returned tensor.
        out_channel_first: layout of ``y``.
        scans: scan pattern (0-3).

    Returns:
        ``(B, 4, D, H*W)`` when ``in_channel_first`` else ``(B, H*W, 4, D)``.
    """
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            def to_row_major(seq):
                # Column-major sequence -> row-major sequence.
                return seq.view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3)

            parts = [
                y[:, 0],
                to_row_major(y[:, 1]),
                torch.flip(y[:, 2], dims=[-1]),
                torch.flip(to_row_major(y[:, 3]), dims=[-1]),
            ]
            y = torch.stack(parts, dim=1)
        elif scans == 1:
            pass  # already row-major
        elif scans == 2:
            parts = [y[:, 0], y[:, 1]]
            parts += [torch.flip(y[:, k], dims=[-1]) for k in (2, 3)]
            y = torch.stack(parts, dim=1)
        elif scans == 3:
            parts = [y[:, 0, :, :].contiguous().view(B, D, -1)]
            # Quarter turns 1 and 3 swapped H and W, so they are viewed as (W, H).
            for turns, (rows, cols) in ((1, (W, H)), (2, (H, W)), (3, (W, H))):
                rotated = y.view(B, K, D, rows, cols)[:, turns, :, :, :]
                parts.append(torch.rot90(rotated, -turns, dims=(2, 3)).flatten(2, 3))
            y = torch.stack(parts, dim=1)
    else:
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            def to_row_major(seq):
                return seq.view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2)

            parts = [
                y[:, :, 0],
                to_row_major(y[:, :, 1]),
                torch.flip(y[:, :, 2], dims=[1]),
                torch.flip(to_row_major(y[:, :, 3]), dims=[1]),
            ]
            y = torch.stack(parts, dim=2)
        elif scans == 1:
            pass  # already row-major
        elif scans == 2:
            parts = [y[:, :, 0], y[:, :, 1]]
            parts += [torch.flip(y[:, :, k], dims=[1]) for k in (2, 3)]
            y = torch.stack(parts, dim=2)
        elif scans == 3:
            parts = [y[:, :, 0, :].contiguous().view(B, -1, D)]
            for turns, (rows, cols) in ((1, (W, H)), (2, (H, W)), (3, (W, H))):
                rotated = y.view(B, rows, cols, K, D)[:, :, :, turns, :]
                parts.append(torch.rot90(rotated, -turns, dims=(1, 2)).flatten(1, 2))
            y = torch.stack(parts, dim=2)

    # Convert between (B, 4, D, L) and (B, L, 4, D) when the layouts differ.
    if out_channel_first and (not in_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif in_channel_first and (not out_channel_first):
        y = y.permute(0, 2, 3, 1).contiguous()

    return y
| |
|
| |
|
class CrossScanF(torch.autograd.Function):
    """Pure-PyTorch cross-scan whose gradient is the matching cross-merge."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Scan ``x`` into four sequences and remember what backward needs.

        ``x`` is ``(B, C, H, W)`` / ``(B, H, W, C)``, or with an extra direction
        axis of size 4 when ``one_by_one`` is set.
        """
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        if one_by_one:
            if in_channel_first:
                B, _, C, H, W = x.shape
            else:
                B, H, W, _, C = x.shape
            scan = cross_scan1b1_fwd
        else:
            if in_channel_first:
                B, C, H, W = x.shape
            else:
                B, H, W, C = x.shape
            scan = cross_scan_fwd
        ctx.shape = (B, C, H, W)

        return scan(x, in_channel_first, out_channel_first, scans)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """Merge the gradient of the four sequences back to the input layout."""
        channel_first_in = ctx.in_channel_first
        channel_first_out = ctx.out_channel_first
        B, C, H, W = ctx.shape

        # Restore the spatial axes that the forward pass flattened.
        if channel_first_out:
            ys = ys.view(B, -1, C, H, W)
        else:
            ys = ys.view(B, H, W, -1, C)

        merge = cross_merge1b1_fwd if ctx.one_by_one else cross_merge_fwd
        grad = merge(ys, channel_first_in, channel_first_out, ctx.scans)

        if ctx.one_by_one:
            grad = grad.view(B, 4, -1, H, W) if channel_first_in else grad.view(B, H, W, 4, -1)
        else:
            grad = grad.view(B, -1, H, W) if channel_first_in else grad.view(B, H, W, -1)

        # One gradient slot per forward argument; only ``x`` is differentiable.
        return grad, None, None, None, None
| |
|
| |
|
class CrossMergeF(torch.autograd.Function):
    """Pure-PyTorch cross-merge whose gradient is the matching cross-scan."""

    @staticmethod
    def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Merge ``ys``, shaped ``(B, K, C, H, W)`` or ``(B, H, W, K, C)``."""
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        if out_channel_first:
            B, _, C, H, W = ys.shape
        else:
            B, H, W, _, C = ys.shape
        ctx.shape = (B, C, H, W)

        merge = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
        return merge(ys, in_channel_first, out_channel_first, scans)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Scan the incoming gradient back out to the four-direction layout."""
        channel_first_in = ctx.in_channel_first
        channel_first_out = ctx.out_channel_first
        B, C, H, W = ctx.shape

        # Restore the spatial axes that the forward pass flattened.
        if ctx.one_by_one:
            x = x.view(B, 4, C, H, W) if channel_first_in else x.view(B, H, W, 4, C)
            scan = cross_scan1b1_fwd
        else:
            x = x.view(B, C, H, W) if channel_first_in else x.view(B, H, W, C)
            scan = cross_scan_fwd

        grad = scan(x, channel_first_in, channel_first_out, ctx.scans)
        if channel_first_out:
            grad = grad.view(B, 4, C, H, W)
        else:
            grad = grad.view(B, H, W, 4, C)

        # One gradient slot per forward argument; only ``ys`` is differentiable.
        return grad, None, None, None, None
| |
|
| |
|
| | |
| |
|
# The kernel is only defined when Triton imported successfully. Decorating
# unconditionally would raise NameError on ``triton`` / ``tl`` at import time
# and defeat the PyTorch fallback announced by the import guard.
if WITH_TRITON:
    @triton.jit
    def triton_cross_scan_flex(
        x: tl.tensor,
        y: tl.tensor,
        x_layout: tl.constexpr,
        y_layout: tl.constexpr,
        operation: tl.constexpr,
        onebyone: tl.constexpr,
        scans: tl.constexpr,
        BC: tl.constexpr,
        BH: tl.constexpr,
        BW: tl.constexpr,
        DC: tl.constexpr,
        DH: tl.constexpr,
        DW: tl.constexpr,
        NH: tl.constexpr,
        NW: tl.constexpr,
    ):
        # Fused cross-scan / cross-merge over one (BH x BW) spatial tile.
        #
        # x_layout / y_layout: 0 = channel-first, 1 = channel-last (the callers
        #   pass ``0 if *_channel_first else 1``).
        # operation: 0 = scan (read x, write the four directions of y);
        #            1 = merge (read the four directions of y, write x).
        # onebyone:  0 = x holds one map; otherwise x holds four maps, one per
        #            direction.
        # scans:     0 = row/col/reversed, 1 = four identical, 2 = two forward
        #            and two reversed, 3 = quarter-turn rotations.
        # BC/BH/BW:  tile sizes; DC/DH/DW: tensor sizes; NH/NW: tile counts.
        # Launch grid: (NH * NW, number of channel blocks, batch).

        i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
        i_h, i_w = (i_hw // NW), (i_hw % NW)
        # Mask out the parts of the tile that fall outside the (DH, DW) map.
        _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
        _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
        _mask_hw = _mask_h[:, None] & _mask_w[None, :]
        # Channels handled by this program; the last channel block may be partial.
        _for_C = min(DC - i_c * BC, BC)

        # Forward and mirrored coordinates of every element in the tile.
        pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
        pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
        neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
        neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
        if scans == 0:
            # row-major, column-major, and both reversed
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = pos_w * DH + pos_h
            HWRoute2 = neg_h * DW + neg_w
            HWRoute3 = neg_w * DH + neg_h
        elif scans == 1:
            # four identical row-major routes
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = HWRoute0
            HWRoute2 = HWRoute0
            HWRoute3 = HWRoute0
        elif scans == 2:
            # two row-major routes, two reversed routes
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = HWRoute0
            HWRoute2 = neg_h * DW + neg_w
            HWRoute3 = HWRoute2
        elif scans == 3:
            # quarter-turn rotations
            HWRoute0 = pos_h * DW + pos_w
            HWRoute1 = neg_w * DH + pos_h
            HWRoute2 = neg_h * DW + neg_w
            HWRoute3 = pos_w * DH + neg_h

        # Number of elements in one (C, H, W) map.
        _tmp1 = DC * DH * DW

        # y always carries four directions per batch element.
        y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
        if y_layout == 0:
            p_y1 = y_ptr_base + HWRoute0
            p_y2 = y_ptr_base + _tmp1 + HWRoute1
            p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
            p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
        else:
            p_y1 = y_ptr_base + HWRoute0 * 4 * DC
            p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
            p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
            p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC

        if onebyone == 0:
            x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
            if x_layout == 0:
                p_x = x_ptr_base + HWRoute0
            else:
                p_x = x_ptr_base + HWRoute0 * DC

            if operation == 0:
                # scan: broadcast each x element to its four positions in y
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    _x = tl.load(p_x + _idx_x, mask=_mask_hw)
                    tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
                    tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
            elif operation == 1:
                # merge: sum the four directions of y into x
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
                    _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
                    _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
                    _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
                    tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)

        else:
            # one-by-one: x also holds four maps, always addressed row-major
            x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
            if x_layout == 0:
                p_x1 = x_ptr_base + HWRoute0
                p_x2 = p_x1 + _tmp1
                p_x3 = p_x2 + _tmp1
                p_x4 = p_x3 + _tmp1
            else:
                p_x1 = x_ptr_base + HWRoute0 * 4 * DC
                p_x2 = p_x1 + DC
                p_x3 = p_x2 + DC
                p_x4 = p_x3 + DC

            if operation == 0:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
            else:
                for idxc in range(_for_C):
                    _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                    _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                    # Loads are masked like every other access in this kernel:
                    # an unmasked load reads outside the tensor on partial edge
                    # tiles (H or W not a multiple of the tile size).
                    tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y, mask=_mask_hw), mask=_mask_hw)
                    tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y, mask=_mask_hw), mask=_mask_hw)
| |
|
| |
|
class CrossScanTritonF(torch.autograd.Function):
    """Cross-scan backed by ``triton_cross_scan_flex`` (forward = scan, backward = merge)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Launch the kernel in scan mode and return the four-direction tensor."""
        if one_by_one and in_channel_first:
            B, _, C, H, W = x.shape
        elif one_by_one:
            B, H, W, _, C = x.shape
        elif in_channel_first:
            B, C, H, W = x.shape
        else:
            B, H, W, C = x.shape
        B, C, H, W = int(B), int(C), int(H), int(W)

        # Tile sizes and the resulting number of tiles per axis.
        BC, BH, BW = 1, 32, 32
        NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)

        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)

        if out_channel_first:
            y = x.new_empty((B, 4, C, H * W))
        else:
            y = x.new_empty((B, H * W, 4, C))

        x_layout = 0 if in_channel_first else 1
        y_layout = 0 if out_channel_first else 1
        onebyone = 0 if not one_by_one else 1
        grid = (NH * NW, NC, B)
        triton_cross_scan_flex[grid](
            x.contiguous(), y,
            x_layout, y_layout, 0, onebyone, scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return y

    @staticmethod
    def backward(ctx, y: torch.Tensor):
        """Launch the kernel in merge mode to produce the gradient w.r.t. ``x``."""
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape

        if ctx.one_by_one:
            grad_shape = (B, 4, C, H, W) if ctx.in_channel_first else (B, H, W, 4, C)
        else:
            grad_shape = (B, C, H, W) if ctx.in_channel_first else (B, H, W, C)
        x = y.new_empty(grad_shape)

        x_layout = 0 if ctx.in_channel_first else 1
        y_layout = 0 if ctx.out_channel_first else 1
        onebyone = 0 if not ctx.one_by_one else 1
        grid = (NH * NW, NC, B)
        triton_cross_scan_flex[grid](
            x, y.contiguous(),
            x_layout, y_layout, 1, onebyone, ctx.scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        # One gradient slot per forward argument; only ``x`` is differentiable.
        return x, None, None, None, None
| |
|
| |
|
class CrossMergeTritonF(torch.autograd.Function):
    """Cross-merge backed by ``triton_cross_scan_flex`` (forward = merge, backward = scan)."""

    @staticmethod
    def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        """Merge the four-direction tensor ``y`` into ``x``.

        ``y`` is ``(B, 4, C, H, W)`` when ``out_channel_first`` else
        ``(B, H, W, 4, C)``. The result has its spatial axes flattened and keeps
        the direction axis only when ``one_by_one`` is set.
        """
        if out_channel_first:
            B, _, C, H, W = y.shape
        else:
            B, H, W, _, C = y.shape
        B, C, H, W = int(B), int(C), int(H), int(W)
        # Tile sizes and the resulting number of tiles per axis.
        BC, BH, BW = 1, 32, 32
        NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
        if one_by_one:
            x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
        else:
            x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
        # operation=1 selects the merge direction of the shared kernel.
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x, y.contiguous(),
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return x

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Scan the incoming gradient back out to the layout of ``y``."""
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape
        y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
        # operation=0 selects the scan direction of the shared kernel.
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x.contiguous(), y,
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        # forward() takes five arguments after ctx, so return exactly five
        # gradients, matching the other autograd Functions in this file.
        return y, None, None, None, None
| |
|
| |
|
| | |
def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    """Cross-scan ``x``, choosing the Triton or the PyTorch implementation.

    The Triton kernel is used only for CUDA tensors, when Triton is available
    and ``force_torch`` is false; everything else goes through ``CrossScanF``.
    See ``cross_scan_fwd`` / ``cross_scan1b1_fwd`` for layouts and ``scans``.
    """
    if not x.is_cuda:
        return CrossScanF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)

    use_triton = WITH_TRITON and (not force_torch)
    impl = CrossScanTritonF if use_triton else CrossScanF
    # Make the tensor's device current for the duration of the call.
    with torch.cuda.device(x.device):
        return impl.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
| |
|
| |
|
| | |
def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    """Cross-merge ``y``, choosing the Triton or the PyTorch implementation.

    The Triton kernel is used only for CUDA tensors, when Triton is available
    and ``force_torch`` is false; everything else goes through ``CrossMergeF``.
    See ``cross_merge_fwd`` / ``cross_merge1b1_fwd`` for layouts and ``scans``.
    """
    if not y.is_cuda:
        return CrossMergeF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)

    use_triton = WITH_TRITON and (not force_torch)
    impl = CrossMergeTritonF if use_triton else CrossMergeF
    # Make the tensor's device current for the duration of the call.
    with torch.cuda.device(y.device):
        return impl.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Probe for the compiled ``selective_scan_cuda`` extension; when it cannot be
# imported, selective_scan_fn uses the PyTorch reference implementation.
try:
    import selective_scan_cuda
except ImportError:
    WITH_SELECTIVESCAN_MAMBA = False
else:
    WITH_SELECTIVESCAN_MAMBA = True
| |
|
| |
|
def selective_scan_torch(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor = None,
    delta_bias: torch.Tensor = None,
    delta_softplus=True,
    oflex=True,
    *args,
    **kwargs
):
    """Reference selective scan written as a sequential PyTorch loop.

    Args:
        u: input of shape ``(Batch, K*Cdim, L)``.
        delta: step sizes, same shape as ``u``.
        A: state matrix of shape ``(K*Cdim, N)``.
        B, C: input/output projections of shape ``(Batch, K, N, L)``; each of
            the ``K`` groups is shared by ``Cdim`` consecutive channels.
        D: optional skip weight, broadcast as ``u * D[..., None]``.
        delta_bias: optional bias added to ``delta`` along the channel axis.
        delta_softplus: apply softplus to ``delta`` (after the bias).
        oflex: when false, cast the result back to the dtype of ``u``;
            otherwise the float32 result is returned as is.
        *args, **kwargs: accepted and ignored (e.g. a trailing ``backend``).

    Returns:
        Tensor of shape ``(Batch, K*Cdim, L)``.
    """
    input_dtype = u.dtype
    Batch, K, N, L = B.shape
    KCdim = u.shape[1]
    Cdim = KCdim // K
    assert u.shape == (Batch, KCdim, L)
    assert delta.shape == (Batch, KCdim, L)
    assert A.shape == (KCdim, N)
    assert C.shape == B.shape

    if delta_bias is not None:
        delta = delta + delta_bias[..., None]
    if delta_softplus:
        delta = torch.nn.functional.softplus(delta)

    # The recurrence is evaluated in float32 regardless of the input dtype.
    u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
    # Give every channel of a group its own copy of that group's B and C.
    B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
    C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
    decay = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    drive = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)

    # state_t = decay_t * state_{t-1} + drive_t ;  y_t = <state_t, C_t>
    state = A.new_zeros((Batch, KCdim, N))
    outputs = []
    for step in range(L):
        state = decay[:, :, step, :] * state + drive[:, :, step, :]
        outputs.append(torch.einsum('bdn,bdn->bd', state, C[:, :, :, step]))
    y = torch.stack(outputs, dim=2)

    if D is not None:
        y = y + u * D.unsqueeze(-1)
    if oflex:
        return y
    return y.to(dtype=input_dtype)
| |
|
| |
|
class SelectiveScanCuda(torch.autograd.Function):
    """Autograd wrapper around the compiled selective-scan kernels.

    ``backend`` selects the extension: ``"mamba"`` uses ``selective_scan_cuda``
    and ``"oflex"`` uses ``selective_scan_cuda_oflex``.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
        """Run the forward kernel and save the tensors needed by backward.

        Raises:
            ValueError: if ``backend`` resolves to anything other than
                ``"mamba"`` or ``"oflex"``.
        """
        ctx.delta_softplus = delta_softplus
        # Default to the mamba kernel when no backend was requested and the
        # extension is importable.
        if backend is None and WITH_SELECTIVESCAN_MAMBA:
            backend = "mamba"
        ctx.backend = backend
        if backend == "oflex":
            # NOTE(review): ``selective_scan_cuda_oflex`` is not imported in the
            # visible part of this file -- confirm it is made available elsewhere.
            out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
        elif backend == "mamba":
            out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        else:
            # Falling through would fail later with an UnboundLocalError on
            # ``out``; report the actual problem instead.
            raise ValueError(f"Unsupported selective scan backend: {backend!r}")
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        """Run the backward kernel of the backend chosen in ``forward``."""
        u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
        backend = ctx.backend
        # The output gradient is made contiguous along its last dimension
        # before it is handed to the kernel.
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if backend == "oflex":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
                u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
            )
        elif backend == "mamba":
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
                u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
                False
            )
        # One slot per forward argument; the last three (delta_softplus, oflex,
        # backend) are not differentiable.
        return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
| |
|
| |
|
def selective_scan_fn(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor = None,
    delta_bias: torch.Tensor = None,
    delta_softplus=True,
    oflex=True,
    backend=None,
):
    """Run the selective scan with the requested backend.

    The PyTorch reference implementation is used when ``backend == "torch"`` or
    when the ``selective_scan_cuda`` extension could not be imported; otherwise
    the call goes through ``SelectiveScanCuda``. All arguments are forwarded
    positionally, ``backend`` included.
    """
    use_reference = backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA)
    if use_reference:
        return selective_scan_torch(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
    return SelectiveScanCuda.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
| |
|
| | |
| | |
| | |
| |
|
class DASSLinear2d(nn.Linear):
    """Linear layer applied over the channel axis of channel-first tensors.

    The weight is stored as a 2-D ``nn.Linear`` weight and used as a 1x1
    convolution kernel, optionally grouped.
    """

    def __init__(self, *args, groups=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.groups = groups

    def forward(self, x: torch.Tensor):
        """Apply the projection to ``(B, C, H, W)`` or ``(B, C, L)`` input.

        Any other rank falls through and returns ``None``.
        """
        rank = x.dim()
        if rank == 4:
            kernel = self.weight[:, :, None, None]
            return F.conv2d(x, kernel, self.bias, groups=self.groups)
        if rank == 3:
            kernel = self.weight[:, :, None]
            return F.conv1d(x, kernel, self.bias, groups=self.groups)
        return None

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Reshape an incoming weight to this module's own weight shape, so that
        # checkpoints storing it as a convolution kernel still load.
        weight_key = prefix + "weight"
        own_weight = self.state_dict()["weight"]
        if weight_key in state_dict:
            state_dict[weight_key] = state_dict[weight_key].view_as(own_weight)
        return super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
| |
|
| |
|
class DASSLayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel axis of a ``(B, C, H, W)`` tensor."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor):
        # nn.LayerNorm normalizes trailing dimensions, so move channels last,
        # normalize, and move them back.
        channels_last = x.permute(0, 2, 3, 1)
        normalized = super().forward(channels_last)
        return normalized.permute(0, 3, 1, 2)
| |
|
| |
|
class DASSPatchEmbeddings(nn.Module):
    """
    Turns `input_values` into the initial patch-embedding feature map using two strided convolutions, each followed
    by a channel-wise layer norm. The output is a 4-D channel-first tensor with `embed_dim` channels.
    """

    def __init__(self, patch_size=4,embed_dim=96):
        super().__init__()

        # Each convolution strides by half the patch size.
        step = patch_size // 2
        conv_kwargs = dict(kernel_size=step + 1, stride=step, padding=1)
        half_dim = embed_dim // 2

        stages = [
            nn.Conv2d(1, half_dim, **conv_kwargs),
            DASSLayerNorm2d(half_dim),
            nn.GELU(),
            nn.Conv2d(half_dim, embed_dim, **conv_kwargs),
            DASSLayerNorm2d(embed_dim),
        ]
        self.projection = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # assumes x is 3-D, presumably (batch, time, frequency) -- TODO confirm.
        # Add a singleton channel axis, swap the last two axes, then embed.
        x = x.unsqueeze(1).transpose(2, 3)
        return self.projection(x)
| |
|
| |
|
class DASSDowsample(nn.Module):
    """
    Downsamples a channel-first feature map with a 3x3, stride-2 convolution, optionally followed by a channel-wise
    layer normalization.
    """
    def __init__(self, dim, out_dim, use_norm=True):
        super().__init__()
        self.down = nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1)
        if use_norm:
            self.norm = DASSLayerNorm2d(out_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.down(x))
| |
|
| |
|
class DASSMlp(nn.Module):
    """Two-layer MLP over the channel axis, built from ``DASSLinear2d`` layers.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when
    they are left unset (or falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = DASSLinear2d(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = DASSLinear2d(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # fc1 -> activation -> dropout -> fc2 -> dropout
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
| |
|
| |
|
| | class SS2D(nn.Module): |
def __init__(
    self,
    # basic dims
    d_model=96,
    d_state=16,
    ssm_ratio=2.0,
    dt_rank="auto",
    act_layer=nn.SiLU,
    # depth-wise convolution
    d_conv=3,
    conv_bias=True,
    # projections
    dropout=0.0,
    bias=False,
    # dt initialisation
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    dt_init_floor=1e-4,
    **kwargs,
):
    """Build the projections, depth-wise conv and SSM parameters of the block.

    ``dt_rank="auto"`` resolves to ``ceil(d_model / 16)``. The depth-wise
    convolution is only created when ``d_conv > 1``. Extra keyword arguments
    are accepted and ignored. Sub-modules are created in a fixed order so that
    parameter registration and random initialisation stay reproducible.
    """
    super().__init__()
    self.k_group = 4
    self.d_model = int(d_model)
    self.d_state = int(d_state)
    self.d_inner = int(ssm_ratio * d_model)
    if dt_rank == "auto":
        dt_rank = math.ceil(self.d_model / 16)
    self.dt_rank = int(dt_rank)
    self.forward_core = partial(self.forward_corev2, force_fp32=False, no_einsum=True)
    self.with_dconv = d_conv > 1

    # input projection
    self.in_proj = DASSLinear2d(self.d_model, self.d_inner, bias=bias)
    self.act: nn.Module = act_layer()

    # depth-wise convolution
    if self.with_dconv:
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
        )

    # per-direction projections producing (dt, B, C), and dt up-projection
    per_group_out = self.dt_rank + self.d_state * 2
    self.x_proj = DASSLinear2d(self.d_inner, self.k_group * per_group_out, groups=self.k_group, bias=False)
    self.dt_projs = DASSLinear2d(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False)

    # output projection
    self.out_proj = DASSLinear2d(self.d_inner, self.d_model, bias=bias)
    self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

    # SSM parameters; the stacked dt weights are copied into ``dt_projs`` and
    # are not kept as a separate parameter.
    A_logs, Ds, stacked_dt_weight, dt_bias = self.init_dt_A_D(
        self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
        k_group=self.k_group,
    )
    self.A_logs = A_logs
    self.Ds = Ds
    self.dt_projs_bias = dt_bias
    self.dt_projs.weight.data = stacked_dt_weight.data.view(self.dt_projs.weight.shape)

    # normalisation of the scan output
    self.out_norm = DASSLayerNorm2d(self.d_inner)
| |
|
| | @staticmethod |
| | def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4): |
| | dt_proj = nn.Linear(dt_rank, d_inner, bias=True) |
| |
|
| | dt_init_std = dt_rank**-0.5 * dt_scale |
| | if dt_init == "constant": |
| | nn.init.constant_(dt_proj.weight, dt_init_std) |
| | elif dt_init == "random": |
| | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) |
| | else: |
| | raise NotImplementedError |
| |
|
| | dt = torch.exp( |
| | torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) |
| | + math.log(dt_min) |
| | ).clamp(min=dt_init_floor) |
| |
|
| | inv_dt = dt + torch.log(-torch.expm1(-dt)) |
| | with torch.no_grad(): |
| | dt_proj.bias.copy_(inv_dt) |
| | |
| | return dt_proj |
| |
|
| | @staticmethod |
| | def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): |
| | A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous() |
| | A_log = torch.log(A) |
| | if copies > 0: |
| | A_log = A_log[None].repeat(copies, 1, 1).contiguous() |
| | if merge: |
| | A_log = A_log.flatten(0, 1) |
| | A_log = nn.Parameter(A_log) |
| | A_log._no_weight_decay = True |
| | return A_log |
| |
|
| | @staticmethod |
| | def D_init(d_inner, copies=-1, device=None, merge=True): |
| | D = torch.ones(d_inner, device=device) |
| | if copies > 0: |
| | D = D[None].repeat(copies, 1).contiguous() |
| | if merge: |
| | D = D.flatten(0, 1) |
| | D = nn.Parameter(D) |
| | D._no_weight_decay = True |
| | return D |
| |
|
| | @classmethod |
| | def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4): |
| | dt_projs = [ |
| | cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor) |
| | for _ in range(k_group) |
| | ] |
| | dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) |
| | dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) |
| | del dt_projs |
| | |
| | A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) |
| | Ds = cls.D_init(d_inner, copies=k_group, merge=True) |
| | return A_logs, Ds, dt_projs_weight, dt_projs_bias |
| |
|
| | def forward_corev2( |
| | self, |
| | x: torch.Tensor, |
| | force_fp32=False, |
| | no_einsum=True, |
| | ): |
| | B, D, H, W = x.shape |
| | N = self.d_state |
| | L = H * W |
| |
|
| | xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True) |
| | x_dbl = self.x_proj(xs.view(B, -1, L)) |
| | dts, Bs, Cs = torch.split(x_dbl.view(B, self.k_group, -1, L), [self.dt_rank, N, N], dim=2) |
| | dts = dts.contiguous().view(B, -1, L) |
| | dts = self.dt_projs(dts) |
| |
|
| | xs = xs.view(B, -1, L) |
| | dts = dts.contiguous().view(B, -1, L) |
| | As = -self.A_logs.to(torch.float32).exp() |
| | Ds = self.Ds.to(torch.float32) |
| | Bs = Bs.contiguous().view(B, self.k_group, N, L) |
| | Cs = Cs.contiguous().view(B, self.k_group, N, L) |
| | delta_bias = self.dt_projs_bias.view(-1).to(torch.float32) |
| | |
| | ys = selective_scan_fn( |
| | xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus=True, backend="mamba" |
| | ).view(B, self.k_group, -1, H, W) |
| | |
| | y = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True) |
| | y = y.view(B, -1, H, W) |
| | y = self.out_norm(y) |
| | return y.to(x.dtype) |
| |
|
| | def forward(self, x: torch.Tensor): |
| | x = self.in_proj(x) |
| | x = self.conv2d(x) |
| | |
| | x = self.act(x) |
| | y = self.forward_core(x) |
| | |
| | out = self.dropout(self.out_proj(y)) |
| | return out |
| |
|
| |
|
class VSSBlock(nn.Module):
    """Residual block with an optional SS2D branch followed by an optional MLP branch.

    A branch is built only when its ratio (``ssm_ratio`` / ``mlp_ratio``) is
    positive. With ``post_norm=True`` the norm is applied after the branch
    operator, otherwise before it. Extra keyword arguments are accepted and
    ignored.
    """

    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        ssm_d_state: int = 1,
        ssm_ratio=1.0,
        ssm_dt_rank: Any = "auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv: int = 3,
        ssm_conv_bias=False,
        ssm_drop_rate: float = 0,
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate: float = 0.0,
        use_checkpoint: bool = False,
        post_norm: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.ssm_branch = ssm_ratio > 0
        self.mlp_branch = mlp_ratio > 0
        self.use_checkpoint = use_checkpoint
        self.post_norm = post_norm

        if self.ssm_branch:
            self.norm = DASSLayerNorm2d(hidden_dim)
            ssm_options = dict(
                d_model=hidden_dim,
                d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                dropout=ssm_drop_rate,
            )
            self.op = SS2D(**ssm_options)

        # One stochastic-depth module is shared by both residual branches.
        self.drop_path = DropPath(drop_path)

        if self.mlp_branch:
            self.norm2 = DASSLayerNorm2d(hidden_dim)
            self.mlp = DASSMlp(
                in_features=hidden_dim,
                hidden_features=int(hidden_dim * mlp_ratio),
                act_layer=mlp_act_layer,
                drop=mlp_drop_rate,
            )

    def _forward(self, input: torch.Tensor):
        """Apply the enabled residual branches in order: SSM first, then MLP."""
        x = input
        if self.ssm_branch:
            branch = self.norm(self.op(x)) if self.post_norm else self.op(self.norm(x))
            x = x + self.drop_path(branch)
        if self.mlp_branch:
            branch = self.norm2(self.mlp(x)) if self.post_norm else self.mlp(self.norm2(x))
            x = x + self.drop_path(branch)
        return x

    def forward(self, input: torch.Tensor):
        """Run ``_forward``, through ``checkpoint.checkpoint`` when enabled."""
        if not self.use_checkpoint:
            return self._forward(input)
        return checkpoint.checkpoint(self._forward, input)
| |
|
class DASSLayer(nn.Module):
    """One stage: ``depth`` ``VSSBlock`` modules followed by a downsample module.

    Args:
        input_dim: channel width passed to every block as ``hidden_dim``.
        depth: number of blocks in the stage.
        drop_path: a single rate used for all blocks, or a list/tuple holding
            one rate per block.
        norm_layer: forwarded to every ``VSSBlock``.
        downsample: module applied after the blocks. ``None`` (the default)
            means no downsampling; a fresh ``nn.Identity`` is created for each
            layer.
        use_checkpoint: forwarded to every ``VSSBlock``.
        **kwargs: forwarded unchanged to every ``VSSBlock``.
    """

    def __init__(
        self,
        input_dim,
        depth,
        drop_path=0.0,
        norm_layer=DASSLayerNorm2d,
        downsample=None,
        use_checkpoint=False,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.use_checkpoint = use_checkpoint

        # Accept any list/tuple of per-block rates, not only ``list``.
        per_block_rates = isinstance(drop_path, (list, tuple))

        self.blocks = nn.ModuleList(
            [
                VSSBlock(
                    hidden_dim=input_dim,
                    drop_path=drop_path[i] if per_block_rates else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint,
                    **kwargs,
                )
                for i in range(depth)
            ]
        )

        # Build the identity here rather than in the signature: a default
        # argument is evaluated once, so a module default would be the same
        # object in every layer that relies on it.
        self.downsample = nn.Identity() if downsample is None else downsample

    def forward(self, x):
        """Apply every block in order, then the downsample module."""
        for block in self.blocks:
            x = block(x)

        x = self.downsample(x)
        return x
| |
|
class DASSPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that wires ``DASSConfig`` into the ``PreTrainedModel``
    machinery (weight initialisation plus downloading and loading of
    pretrained checkpoints).
    """

    config_class = DASSConfig
    base_model_prefix = "dass"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialise one sub-module in place.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and a zero bias
        when a bias exists; ``nn.LayerNorm`` gets unit weight and zero bias.
        Every other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)
| |
|
| |
|
class DASSModel(DASSPreTrainedModel):
    """DASS backbone: patch embedding, a stack of ``DASSLayer`` stages, a final
    norm and global average pooling.

    ``config.dims`` may be a list of per-stage widths or a single int; an int
    is expanded to ``dims * 2**i`` for stage ``i``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # The stage count must be set before ``dims`` is expanded: the
        # expansion below reads ``self.num_layers``, and reading it before
        # assignment raises AttributeError for an integer ``config.dims``.
        self.num_layers = len(config.depths)

        dims = config.dims
        if isinstance(dims, int):
            dims = [int(dims * 2**i_layer) for i_layer in range(self.num_layers)]

        self.dims = dims
        self.patch_embeddings = DASSPatchEmbeddings(patch_size=config.patch_size,
                                                    embed_dim=dims[0])

        # Stochastic-depth rates grow linearly from 0 to drop_path_rate over
        # all blocks of all stages.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        self.num_features = dims[-1]

        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            layer = DASSLayer(
                input_dim=self.dims[i],
                depth=config.depths[i],
                # Slice out the rates that belong to this stage's blocks.
                drop_path=dpr[sum(config.depths[:i]):sum(config.depths[:i+1])],
                # Every stage but the last downsamples to the next width.
                downsample=DASSDowsample(self.dims[i], self.dims[i+1]) if i < self.num_layers - 1 else nn.Identity(),
                use_checkpoint=config.use_checkpoint,
            )
            self.layers.append(layer)

        self.norm = DASSLayerNorm2d(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def get_input_embeddings(self) -> DASSPatchEmbeddings:
        """Return the patch-embedding module."""
        return self.patch_embeddings

    def forward(self, input_values: torch.Tensor):
        """Embed the input, run every stage, normalise and pool.

        Returns the average-pooled features flattened from dim 1 onward
        (``(batch, num_features)`` assuming the stages emit channel-first 4-D
        maps; confirm against the layer definitions).
        """
        x = self.patch_embeddings(input_values)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.avgpool(x).flatten(1)
        return x
| |
|
| |
|
class DASSForAudioClassification(DASSPreTrainedModel):
    """DASS backbone with a linear classification head.

    The head is ``nn.Linear(num_features, config.num_classes)``, or
    ``nn.Identity`` when ``config.num_classes`` is not positive.
    """

    def __init__(self, config):
        super().__init__(config)

        self.num_classes = config.num_classes
        self.dass = DASSModel(config)
        self.head = nn.Linear(self.dass.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()

        # Initialise weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Classify ``input_values`` and optionally compute a loss.

        Args:
            input_values: input passed straight to the backbone.
            labels: optional targets. With ``config.loss_type == "ce"`` they are
                flattened and used with ``CrossEntropyLoss``; with ``"bce"``
                they are used as-is with ``BCEWithLogitsLoss``. Any other loss
                type leaves ``loss`` as ``None``.
            return_dict: ``True`` returns a ``SequenceClassifierOutput``,
                ``False`` returns a plain tuple. ``None`` falls back to the
                config's ``use_return_dict`` setting.

        Returns:
            ``SequenceClassifierOutput(loss, logits, hidden_states)`` where
            ``hidden_states`` holds the pooled backbone features, or the tuple
            ``(loss, logits, features)`` / ``(logits, features)``.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        outputs = self.dass(
            input_values,
        )

        logits = self.head(outputs)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_type = self.config.loss_type
            if loss_type == "ce":
                loss_fct = CrossEntropyLoss()
                # This class defines ``num_classes``; ``num_labels`` does not exist on it.
                loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            # Select BCE with the same ``loss_type`` field as CE. The former
            # ``problem_type == "bce"`` check is kept as a fallback so any
            # config relying on it behaves as before.
            elif loss_type == "bce" or getattr(self.config, "problem_type", None) == "bce":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        # Tuple output is the ``return_dict=False`` path.
        if not return_dict:
            output = (logits,) + (outputs,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs,
        )
| |
|
# Names exported by a star-import of this module.
__all__ = [
    "DASSModel",
    "DASSPreTrainedModel",
    "DASSForAudioClassification",
]
| |
|