Simplify swish for readability
This commit is contained in:
parent
0533150194
commit
6d021ec949
@ -1,20 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class SwishFunction(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input):
|
|
||||||
result = input * torch.sigmoid(input)
|
|
||||||
ctx.save_for_backward(input)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
input, = ctx.saved_variables
|
|
||||||
sigmoid = torch.sigmoid(input)
|
|
||||||
return grad_output * (sigmoid * (1 + input * (1 - sigmoid)))
|
|
||||||
|
|
||||||
|
|
||||||
class Swish(torch.nn.Module):
|
class Swish(torch.nn.Module):
|
||||||
def forward(self, input):
|
def forward(self, x):
|
||||||
return SwishFunction.apply(input)
|
return x * torch.sigmoid(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user