# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
#               2020 Northwestern Polytechnical University (Pengcheng Guo)
#               2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swish() activation function for Conformer."""

import torch
# NOTE: `sin` and `pow` stay importable from this module so its namespace is
# unchanged; the code below calls `torch.sin` / `torch.pow` explicitly so it
# does not depend on a name that shadows the builtin `pow`.
from torch import nn, sin, pow
from torch.nn import Parameter


class Swish(torch.nn.Module):
    """Construct a Swish object."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Swish activation ``x * sigmoid(x)``, applied elementwise."""
        return x * torch.sigmoid(x)


# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
#   LICENSE is in incl_licenses directory.
class Snake(nn.Module):
    """Sine-based periodic activation function with a per-channel frequency.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    Parameters:
        - alpha - per-channel parameter of shape ``(in_features,)``

    References:
        - This activation function is from this paper by Liu Ziyin,
          Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195

    Examples:
        >>> act = Snake(256)
        >>> x = torch.randn(4, 256, 100)
        >>> act(x).shape
        torch.Size([4, 256, 100])
    """

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = False,
    ) -> None:
        """Initialize the activation.

        Args:
            in_features: number of channels ``C``; ``alpha`` holds one value
                per channel.
            alpha: initial value of the frequency parameter in linear-scale
                mode (default 1.0); higher values = higher frequency.
                In log-scale mode the parameter is built as
                ``zeros(in_features) * alpha``, so it starts at zero
                regardless of this value.
            alpha_trainable: value assigned to ``alpha.requires_grad``.
            alpha_logscale: if True, ``alpha`` is stored in log scale and
                exponentiated in ``forward``.
        """
        super().__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # log scale alphas initialized to zeros (exp(0) == 1 in forward)
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:
            # linear scale alphas initialized to `alpha` (ones by default)
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        # Small constant added to alpha before taking the reciprocal.
        self.no_div_by_zero = 0.000000001

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation elementwise.

        Snake := x + 1/a * sin^2(x * a)

        Args:
            x: input tensor; ``alpha`` is reshaped to ``(1, C, 1)`` and
                broadcast against it, so ``x`` is expected to be (B, C, T).

        Returns:
            A new tensor with the activation applied; ``x`` is not modified
            in place.
        """
        # line up with x to [B, C, T]
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        periodic = torch.pow(torch.sin(x * alpha), 2)
        return x + (1.0 / (alpha + self.no_div_by_zero)) * periodic