python Pytorch自定义激活功能?

bcs8qyzn  于 2022-12-28  发布在  Python
关注(0)|答案(3)|浏览(134)

我在Pytorch中实现自定义激活函数时遇到了问题,比如Swish。我应该如何在Pytorch中实现和使用自定义激活函数?

t8e9dugd

t8e9dugd1#

根据您的需求,有四种可能性。您需要问自己两个问题:

    • Q1)**您的激活函数是否具有可学习的参数?

如果yes,您别无选择,只能将激活函数创建为nn.Module类,因为您需要存储这些权重。
如果no,您可以自由地创建一个普通函数或类,这取决于您方便的方式。

    • Q2)**您的激活函数可以表示为现有PyTorch函数的组合吗?

如果yes,您可以简单地将其编写为现有PyTorch函数的组合,而不需要创建定义渐变的backward函数。
如果,则需要手写梯度。

    • 示例1:SiLU功能**

SiLU函数f(x) = x * sigmoid(x)没有任何学习过的权重,可以完全用现有的PyTorch函数编写,因此您可以简单地将其定义为函数:

def silu(x):
    return x * torch.sigmoid(x)

然后简单地使用它,就像你会有torch.relu或任何其他激活函数一样。

    • 示例2:具有学习斜率的SiLU**

在这种情况下,您有一个已学习的参数,即斜率,因此您需要将其创建一个类。

class LearnedSiLU(nn.Module):
    def __init__(self, slope = 1):
        super().__init__()
        self.slope = slope * torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.slope * x * torch.sigmoid(x)
    • 示例3:向后**

如果您需要创建自己的梯度函数,可以查看以下示例:Pytorch: define custom function

xcitsw88

xcitsw882#

您可以编写一个定制的激活函数,如下所示(例如加权Tanh)。

class weightedTanh(nn.Module):
    def __init__(self, weights = 1):
        super().__init__()
        self.weights = weights

    def forward(self, input):
        ex = torch.exp(2*self.weights*input)
        return (ex-1)/(ex+1)

如果使用autograd兼容操作,则不必担心反向传播。

hmae6n7t

hmae6n7t3#

我编写了nn.ModuleSinActivation子类来实现sin激活函数。

class SinActivation(torch.nn.Module):
    def __init__(self):
        super(SinActivation, self).__init__()
        return
    def forward(self, x):
        return torch.sin(x)

相关问题