pytorch 简单线性回归模型参数未更新

cvxl0en2  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(92)

我一直在学习PyTorch的一个教程,创建一个简单线性回归模型。
我已经跟进了相应的这里的代码供参考。

class LinearRegression(nn.Module):

  def __init__(self):
    super().__init__()

    self.bias = nn.Parameter(torch.randn(1, requires_grad = True, dtype = torch.float))
    self.weights = nn.Parameter(torch.randn(1, requires_grad = True, dtype = torch.float))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.weights * x + bias

我使用SGD作为优化器,学习率为0.01。并以MAE作为损失函数。
无论我调整什么参数,偏置参数都不会改变。体重变化很好。
我见过几个Stackoverflow线程谈论克隆参数,但它在我的情况下不起作用。
我已经初始化了模型:

torch.manual_seed(42)

model = LinearRegression()

list(model.parameters())

当我打印参数时,它很好。我哪里做错了。

nszi6y05

nszi6y051#

你可能在forward函数中犯了一个错误,试试下面的代码:

def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weights * x + self.bias  # Fix here: Use self.bias instead of just bias

相关问题