pytorch cuDNN错误:CUDNN_STATUS_BAD_PARAM,有人能解释一下为什么我会得到这个错误,以及我如何纠正它吗?

c3frrgcw  于 2022-12-04  发布在  其他
关注(0)|答案(4)|浏览(472)

我正在尝试用Pytorch实现一个字符LSTM。但是我得到了cudnn_status_bad_params错误。这是训练循环。我在output = model(input_seq)行上得到了错误。

for epoch in tqdm(range(epochs)):
  for i in range(len(seq)//batch_size):
   sidx = i*batch_size
   eidx = sidx + batch_size
   x = seq[sidx:eidx]
   x = torch.tensor(x).cuda()
   input_seq =torch.nn.utils.rnn.pack_padded_sequence(x,seq_lengths,batch_first = True)
   y = out_seq[sidx:eidx]
   output = model(input_seq)
   loss = criterion(output,y)
   loss.backward()
   optimizer.step()
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)   
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    180         else:
    181             result = _impl(input, batch_sizes, hx, self._flat_weights, self.bias,
--> 182                            self.num_layers, self.dropout, self.training, self.bidirectional)
    183         output = result[0]
    184         hidden = result[1:] if self.mode == 'LSTM' else result[1]

 RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
w8f9ii69

w8f9ii691#

我本想对普尔提的回答发表评论,但我不能,所以我在这里为后代加上这句话:
我在CPU上运行模型,我的错误被升级为另一个半帮助错误,我在网上找不到解决方案:

RuntimeError: could not create a descriptor for a dilated convolution forward propagation primitive

对我来说,这是一个conv层,错误地定义为dilation=0而不是1。因此,根据原始错误(CUDNN_STATUS_BAD_PARAM),确保易出错层的参数有效。

oyt4ldly

oyt4ldly2#

我遇到了相同的错误。Here's the solution
您应该将输入类型从float64更改为float32,这意味着您应该键入:

input_seq = input_seq.float()
brjng4g3

brjng4g33#

我得到了同样的错误,如果你切换到CPU,你会得到一个更好的错误描述。在我的情况下,问题是在类型的输入,我给网络。我发送我猜long,而模型需要float。我做了以下的变化和代码工作。基本上切换到CPU提供更好的错误描述。

input_seq = input_seq.float().cuda()
ghhkc1vu

ghhkc1vu4#

我遇到了同样的问题,问题是 Torch ==1.6。解决方案可以在这里找到git issue。看一下。这可能也是你的解决方案。

相关问题