Simplify swish for readability

This commit is contained in:
Yin Li 2019-12-12 18:09:26 -05:00
parent 0533150194
commit 6d021ec949

View File

@ -1,20 +1,6 @@
import torch import torch
class SwishFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
result = input * torch.sigmoid(input)
ctx.save_for_backward(input)
return result
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_variables
sigmoid = torch.sigmoid(input)
return grad_output * (sigmoid * (1 + input * (1 - sigmoid)))
class Swish(torch.nn.Module): class Swish(torch.nn.Module):
def forward(self, input): def forward(self, x):
return SwishFunction.apply(input) return x * torch.sigmoid(x)