pytorch 尝试查看权重和偏差的随机值时出错

ilmyapht  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(112)
from torch import nn

class LinearRegressionModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.weights=nn.parameter(torch.randn(1,requires_gradient=True,dtype=torch.float))
        self.bias=nn.parameter(torch.randn(1,requires_gradient=True,dtype=torch.float))   

    def forward(self,X:torch.tensor)->torch.tensor:
        return self.weights*X+self.bias

torch.manual_seed(32)
model_0=LinearRegressionModel()
model_0

字符串
这三行代码给出了如下错误
如何更正该代码?

TypeError                                 Traceback (most recent call last)
<ipython-input-25-e7c14f2c2a90> in <cell line: 2>()
      1 torch.manual_seed(32)
----> 2 model_0=LinearRegressionModel()
      3 model_0
<ipython-input-23-44002508539b> in __init__(self)
      3     def __init__(self):
      4         super().__init__()
----> 5         self.weights=nn.parameter(torch.randn(1,requires_gradient=True,dtype=torch.float))
      6         self.bias=nn.parameter(torch.randn(1,requires_gradient=True,dtype=torch.float))
      7     def forward(self,X:torch.tensor)->torch.tensor:

TypeError: randn() received an invalid combination of arguments - got (int, dtype=torch.dtype, requires_gradient=bool), but expected one of:

 * (tuple of ints size, *, torch.Generator generator, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


最初,我试图通过module_0.parameters()打印随机权重和偏差,但由于LinearRegressionModel()未定义,因此我试图仅查看module_0最初返回的内容,但它给出了错误。

yc0p9oo0

yc0p9oo01#

查看torch.randndocs,您的关键字args稍微偏离了(您可以在错误消息中看到这一点,因为它告诉您它得到了不正确的参数)。应该是requires_grad=True而不是requires_gradient=True。另外,请确保创建的是nn.parameter.Parameter而不是nn.parameter,因为nn.parameter是模块的名称,而不是类的名称。

相关问题