Simplify swish for readability
This commit is contained in:
parent
0533150194
commit
6d021ec949
@ -1,20 +1,6 @@
|
||||
import torch
|
||||
|
||||
|
||||
class SwishFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
result = input * torch.sigmoid(input)
|
||||
ctx.save_for_backward(input)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_variables
|
||||
sigmoid = torch.sigmoid(input)
|
||||
return grad_output * (sigmoid * (1 + input * (1 - sigmoid)))
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return SwishFunction.apply(input)
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
Loading…
Reference in New Issue
Block a user