pytorch 类型错误:relu():参数'input'(位置1)必须是Tensor,而不是元组,我相信这是因为我有一个LSTM层

kzmpq1sx  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(602)

我相信错误是因为我有一个LSTM层。我如何修改代码,使它将工作正常?任何帮助?

py49o6xq

py49o6xq1#

LSTM层移出顺序层。LSTM返回output, (hn, cn)的元组,其中hn, cn是最后的隐藏状态。
例如,您的init函数将包含如下内容

class module(nn.Module):
    def __init__(self):
        super(nn.Module, self).__init__()
        self.lstm = nn.LSTM(...)
        self.seq = nn.Sequential(...)

转发函数将为

def forward(self, x):
    lstm_out= self.lstm(x)
    out = self.seq(lstm_out[0])
    return out

相关问题