# Provenance: file snapshot at revision de7836d (2,180 bytes), taken from a hosted Space running on A10G.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Hswish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``, element-wise."""

    def __init__(self, inplace=True):
        super().__init__()
        # Forwarded to ``F.relu6`` in ``forward``.
        self.inplace = inplace

    def forward(self, x):
        """Return the hard-swish of ``x``."""
        # ``x + 3.`` is a new tensor, so the in-place relu6 only overwrites
        # that temporary and never ``x`` itself.
        gate = F.relu6(x + 3., inplace=self.inplace)
        return x * gate / 6.
# Hard sigmoid: out = max(0, min(1, slope * x + offset))
# Paddle reference: paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
class Hsigmoid(nn.Module):
    """Hard-sigmoid activation computed as ``relu6(1.2 * x + 3) / 6``."""

    def __init__(self, inplace=True):
        super().__init__()
        # Forwarded to ``F.relu6`` in ``forward``.
        self.inplace = inplace

    def forward(self, x):
        """Return the hard-sigmoid of ``x``, element-wise."""
        # Paddle-style slope: relu6(1.2 * x + 3) / 6 == clip(0.2 * x + 0.5, 0, 1).
        # The torch-style variant would be relu6(x + 3) / 6 instead.
        shifted = 1.2 * x + 3.
        return F.relu6(shifted, inplace=self.inplace) / 6.
class GELU(nn.Module):
    """Module wrapper around ``torch.nn.functional.gelu``."""

    def __init__(self, inplace=True):
        super().__init__()
        # Stored on the module only; ``forward`` does not read it.
        self.inplace = inplace

    def forward(self, x):
        """Return ``gelu(x)`` with the functional's default arguments."""
        return F.gelu(x)
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``, element-wise."""

    def __init__(self, inplace=True):
        super().__init__()
        # When true, ``forward`` overwrites its input tensor with the result.
        self.inplace = inplace

    def forward(self, x):
        """Return the swish of ``x``; in in-place mode the result is ``x`` itself."""
        if not self.inplace:
            return x * torch.sigmoid(x)
        # The sigmoid is evaluated first (as the argument), then multiplied
        # into ``x`` in place.
        x.mul_(torch.sigmoid(x))
        return x
class Activation(nn.Module):
    """Build an activation module from its name and apply it in ``forward``.

    Args:
        act_type: Activation name; lower-cased before matching. Recognised
            names are ``'relu'``, ``'relu6'``, ``'hard_sigmoid'``,
            ``'hard_swish'``, ``'leakyrelu'``, ``'gelu'`` and ``'swish'``.
        inplace: Passed as ``inplace`` to the selected activation module.

    Raises:
        NotImplementedError: If ``act_type`` is ``'sigmoid'`` (explicitly
            unsupported) or any other unrecognised name. The message names
            the rejected activation.
    """

    def __init__(self, act_type, inplace=True):
        super().__init__()
        act_type = act_type.lower()
        if act_type == 'relu':
            self.act = nn.ReLU(inplace=inplace)
        elif act_type == 'relu6':
            self.act = nn.ReLU6(inplace=inplace)
        elif act_type == 'sigmoid':
            # Deliberately rejected, but say so instead of raising a bare error.
            raise NotImplementedError(
                "activation type {!r} is not implemented".format(act_type))
        elif act_type == 'hard_sigmoid':
            self.act = Hsigmoid(inplace=inplace)
        elif act_type == 'hard_swish':
            self.act = Hswish(inplace=inplace)
        elif act_type == 'leakyrelu':
            self.act = nn.LeakyReLU(inplace=inplace)
        elif act_type == 'gelu':
            self.act = GELU(inplace=inplace)
        elif act_type == 'swish':
            self.act = Swish(inplace=inplace)
        else:
            raise NotImplementedError(
                "unsupported activation type: {!r}".format(act_type))

    def forward(self, inputs):
        """Apply the selected activation to ``inputs``."""
        return self.act(inputs)