# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.autograd import Function
import torch
from torch import nn
class GradientReversal(Function):
    """Autograd function implementing gradient reversal.

    Acts as the identity on the forward pass and multiplies the incoming
    gradient by ``-alpha`` on the backward pass (the GRL trick used in
    adversarial / domain-adaptation training).
    """

    # ``forward``/``backward`` on a custom ``Function`` must be static in
    # modern PyTorch; without ``@staticmethod`` calling ``.apply`` raises
    # "Legacy autograd function with non-static forward method is deprecated".
    @staticmethod
    def forward(ctx, x, alpha):
        # Only ``alpha`` is needed to scale the gradient in backward;
        # saving ``x`` as well would waste memory for no benefit.
        ctx.save_for_backward(alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (alpha,) = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Reverse and scale the gradient flowing into ``x``.
            grad_input = -alpha * grad_output
        # ``None`` for ``alpha``: it is a hyperparameter, not trainable.
        return grad_input, None


revgrad = GradientReversal.apply
class GradientReversal(nn.Module):
    """Module wrapper around the gradient-reversal autograd function.

    Forward is the identity; during backpropagation the gradient is
    multiplied by ``-alpha``.
    """

    def __init__(self, alpha):
        super().__init__()
        # Stored as a non-trainable tensor so it can be passed straight
        # into the autograd Function alongside the input.
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self.alpha)