PyTorch 中的简单线性回归：为什么损失随着每个 epoch 不断增大？

h43kikqp  于 2023-04-06  发布在  其他
关注(0)|答案(2)|浏览(146)

我尝试用 PyTorch 建立一个简单的线性回归模型，根据实际温度 temp 预测体感温度 atemp。
我不明白为什么这段代码运行后，损失随着每个 epoch 增大而不是减小，而且所有预测值都与真实值相差甚远。

使用的样本数据

# Sample data used in the post: 100 `temp` values (data_x) and the matching
# 100 `atemp` values (data_y), pasted as NumPy array reprs. Presumably these
# are the `.head(100)` rows of data.csv used by the training code below, and
# `array` is `numpy.array` (TODO confirm — no import is shown).
data_x = array([11.9, 12. , 13.4, 14.8, 15.8, 16.6, 16.7, 16.9, 16.9, 16.9, 16.5,
       15.7, 15.3, 15. , 15. , 14.9, 14.6, 14.2, 14.2, 14. , 13.5, 12.9,
       12.5, 12.4, 12.8, 14.3, 15.6, 16.5, 17. , 17.5, 17.7, 17.7, 17.8,
       17.5, 16.9, 15.6, 14. , 12.2, 11. , 10.6, 10.6, 10.7, 10.9, 10.6,
       10.3,  9.4,  8.7,  7.8,  8.1, 11. , 13.4, 15.2, 16.5, 17.4, 18.1,
       18.5, 18.7, 18.6, 17.7, 16. , 14.6, 13.8, 13. , 12.5, 12. , 11.8,
       11.5, 11.3, 10.9, 10.6, 10.2,  9.9, 10.5, 13.1, 15.3, 17.2, 18.9,
       20.3, 21.2, 21.8, 21.9, 21.5, 20.2, 18.3, 16.8, 15.8, 14.9, 14.2,
       13.6, 13.2, 12.9, 12.7, 12.6, 12.6, 12.6, 12.8, 13.4, 15.5, 17.6,
       19.3])
data_y = array([ 8.9,  9.3, 10.7, 12.1, 13.1, 13.8, 14. , 14.1, 14.3, 14.5, 14.3,
       13.7, 13.2, 12.7, 12.7, 12.5, 11.9, 11.7, 11.7, 11.5, 11.1, 10.6,
       10.3, 10.2, 10.9, 12.5, 12.8, 13.8, 14.6, 14.9, 14.9, 15.1, 15.5,
       15.6, 15.8, 14.7, 13.1, 11.2,  9.6,  9.1,  9.4,  9.7,  9.9,  9.6,
        9.2,  8. ,  7.1,  6.1,  6.5, 10.2, 12.7, 14.3, 15.5, 16.6, 17.4,
       17.7, 17.8, 17.6, 17.2, 15.3, 13.4, 12.4, 11.5, 10.8, 10.1, 10. ,
        9.8,  9.6,  9.3,  9. ,  8.5,  8.1,  8.8, 12. , 14.4, 16.6, 18.5,
       20.1, 21. , 21.3, 21.2, 21.2, 20.1, 17.9, 16.1, 14.6, 13.8, 13.1,
       12.3, 11.8, 11.6, 11.4, 11.3, 11.3, 11.3, 11.4, 12. , 14.6, 16.8,
       18.8])

绘制数据:

代码：

# Fit atemp ~ w * temp + b with a single nn.Linear layer and full-batch SGD.
# NOTE(review): assumes `import pandas as pd`, `import torch`,
# `import torch.nn as nn` and `import torch.optim as optim` ran earlier
# (not shown); `display` is presumably IPython's notebook helper.

# import data from CSV into a pandas DataFrame
bg = pd.read_csv('data.csv')
X_pandas = bg['temp']   # feature: actual temperature
y_pandas = bg['atemp']  # target: apparent ("feels-like") temperature

# convert the first 100 rows to float32 column tensors of shape (N, 1),
# the (samples, features) layout nn.Linear(1, 1) consumes
data_x = X_pandas.head(100).values
data_y = y_pandas.head(100).values
X = torch.tensor(data_x, dtype=torch.float32).reshape(-1, 1)
y = torch.tensor(data_y, dtype=torch.float32).reshape(-1, 1)

# create the model
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()  # mean squared error
# NOTE(review): this learning rate is why the loss explodes. temp is not
# scaled (roughly 8-22 in the sample data, mean(x**2) ~ 220), so the MSE
# curvature w.r.t. the weight is ~2 * mean(x**2) ~ 450, and plain gradient
# descent is only stable for lr below about 2 / 450 ~ 0.0045. At lr=0.01
# every step overshoots the minimum; the printed loss grows ~12x per epoch
# until it overflows to inf. Scale/standardize X or use a much smaller lr.
optimizer = optim.SGD(model.parameters(), lr=0.01)

# train the model: full-batch gradient descent, one update per epoch
n_epochs = 40   # number of epochs to run
for epoch in range(n_epochs):
    # forward pass over all samples at once
    y_pred = model(X)
    # compute loss
    loss = loss_fn(y_pred, y)
    # backward pass: populate .grad on the weight and bias
    loss.backward()
    # update parameters
    optimizer.step()
    # zero gradients so the next backward() does not accumulate onto these
    optimizer.zero_grad()
    # print loss (measured before this epoch's parameter update)
    print(f'epoch: {epoch + 1}, loss = {loss.item():.4f}')

# display the predicted values (detach from the autograd graph before
# converting to a NumPy array)
predicted = model(X).detach().numpy()
display(predicted)

输出

epoch: 1, loss = 16.5762
epoch: 2, loss = 191.0379
epoch: 3, loss = 2291.5081
epoch: 4, loss = 27580.5195
epoch: 5, loss = 332052.6875
epoch: 6, loss = 3997804.2500
epoch: 7, loss = 48132328.0000
epoch: 8, loss = 579498624.0000
epoch: 9, loss = 6976988160.0000
epoch: 10, loss = 84000866304.0000
epoch: 11, loss = 1011344670720.0000
epoch: 12, loss = 12176279470080.0000
epoch: 13, loss = 146598776537088.0000
epoch: 14, loss = 1765004462260224.0000
epoch: 15, loss = 21250117348622336.0000
epoch: 16, loss = 255844948350337024.0000
epoch: 17, loss = 3080297218377252864.0000
epoch: 18, loss = 37085819119396192256.0000
epoch: 19, loss = 446502312996857970688.0000
epoch: 20, loss = 5375748153858603352064.0000
epoch: 21, loss = 64722396677244886974464.0000
epoch: 22, loss = 779237667397586303057920.0000
epoch: 23, loss = 9381773651754967424303104.0000
epoch: 24, loss = 112953739724808869434621952.0000
epoch: 25, loss = 1359928800566679308764971008.0000
epoch: 26, loss = 16373128158657455337028714496.0000
epoch: 27, loss = 197127444146361433227589058560.0000
epoch: 28, loss = 2373354706586702693378941779968.0000
epoch: 29, loss = 28574463232459721913615454830592.0000
epoch: 30, loss = 344027831021918449557295178186752.0000
epoch: 31, loss = 4141990153063893156517557464727552.0000
epoch: 32, loss = 49868270370463502095675094080684032.0000
epoch: 33, loss = 600398977963427833849804206813216768.0000
epoch: 34, loss = inf
epoch: 35, loss = inf
epoch: 36, loss = inf
epoch: 37, loss = inf
epoch: 38, loss = inf
epoch: 39, loss = inf
epoch: 40, loss = inf

预测值:

array([[1.60481241e+21],
       [1.61822441e+21],
       [1.80599158e+21],
       [1.99375890e+21],
       [2.12787834e+21],
       [2.23517393e+21],
       [2.24858593e+21],
       [2.27540965e+21],
       [2.27540965e+21],
       [2.27540965e+21],
       ...

是什么原因导致了这个奇怪的结果?

axzmvihb

axzmvihb1#

看来问题在于：对于这个问题和这些数据来说，0.01 的学习率太高了。输入 temp 未经缩放（约 8–22），MSE 对权重的梯度随 x² 放大，因此每一步更新都会越过最优点，并且越偏越远。
修改下面这一行后，问题得到了解决：
# before
optimizer = optim.SGD(model.parameters(), lr=0.01)

# after: halve the learning rate
# NOTE(review): judging by the ~12x per-epoch loss growth printed in the
# question, plain gradient descent on the unscaled 100-sample data is only
# stable for lr below roughly 0.0045, so 0.005 may still diverge (slowly,
# ~1.5x per epoch). Verify on your data; scaling the inputs is more robust.
optimizer = optim.SGD(model.parameters(), lr=0.005)

kmpatx3s

kmpatx3s2#

即使不改变学习率（LR），对输入进行缩放也会有所帮助：

# Same training script as in the question, but run on the sample arrays
# data_x / data_y from the post with the input rescaled, keeping lr=0.01.
# The question's CSV loading is disabled here:
# X_pandas = bg['temp']
# y_pandas = bg['atemp']

# Divide temp by its maximum (21.9 in the sample data) so X lies in roughly
# [0.36, 1]; this shrinks the gradient w.r.t. the weight enough that
# lr=0.01 no longer diverges.
data_x = data_x/data_x.max()
data_y = data_y  # no-op: the target is left unscaled

# convert to tensors of shape (N, 1)
# data_x = X_pandas.head(100).values
# data_y = y_pandas.head(100).values
X = torch.tensor(data_x, dtype=torch.float32).reshape(-1, 1)
y = torch.tensor(data_y, dtype=torch.float32).reshape(-1, 1)

# create the model
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()  # mean squared error
optimizer = optim.SGD(model.parameters(), lr=0.01)
# NOTE(review): max-scaling stops the divergence but converges slowly: the
# printed loss only falls ~5% per epoch and is still ~27 after 40 epochs,
# and the predictions (~7.5-10.2) have not yet spread out to track the
# targets (~6-21). More epochs, a larger lr, or standardizing X (subtract
# the mean, divide by the std) should fit much better — verify.

# train the model: full-batch gradient descent, one update per epoch
n_epochs = 40   # number of epochs to run
for epoch in range(n_epochs):
    # forward pass over all samples at once
    y_pred = model(X)
    # compute loss
    loss = loss_fn(y_pred, y)
    # backward pass: populate .grad on the weight and bias
    loss.backward()
    # update parameters
    optimizer.step()
    # zero gradients so the next backward() does not accumulate onto these
    optimizer.zero_grad()
    # print loss (measured before this epoch's parameter update)
    print(f'epoch: {epoch + 1}, loss = {loss.item():.4f}')

# display the predicted values (detach from the autograd graph before
# converting to a NumPy array)
predicted = model(X).detach().numpy()
display(predicted)

epoch: 1, loss = 210.4702
epoch: 2, loss = 198.8098
epoch: 3, loss = 187.8156
epoch: 4, loss = 177.4496
epoch: 5, loss = 167.6758
epoch: 6, loss = 158.4604
epoch: 7, loss = 149.7714
epoch: 8, loss = 141.5789
epoch: 9, loss = 133.8544
epoch: 10, loss = 126.5711
epoch: 11, loss = 119.7039
epoch: 12, loss = 113.2290
epoch: 13, loss = 107.1239
epoch: 14, loss = 101.3676
epoch: 15, loss = 95.9400
epoch: 16, loss = 90.8225
epoch: 17, loss = 85.9972
epoch: 18, loss = 81.4475
epoch: 19, loss = 77.1577
epoch: 20, loss = 73.1128
epoch: 21, loss = 69.2989
epoch: 22, loss = 65.7028
epoch: 23, loss = 62.3120
epoch: 24, loss = 59.1148
epoch: 25, loss = 56.1002
epoch: 26, loss = 53.2576
epoch: 27, loss = 50.5773
epoch: 28, loss = 48.0500
epoch: 29, loss = 45.6670
epoch: 30, loss = 43.4200
epoch: 31, loss = 41.3012
epoch: 32, loss = 39.3033
epoch: 33, loss = 37.4193
epoch: 34, loss = 35.6429
epoch: 35, loss = 33.9678
epoch: 36, loss = 32.3883
epoch: 37, loss = 30.8988
epoch: 38, loss = 29.4943
epoch: 39, loss = 28.1699
epoch: 40, loss = 26.9209
array([[ 8.267589 ],
       [ 8.287247 ],
       [ 8.562464 ],
       [ 8.837682 ],
       [ 9.0342655],
       [ 9.191533 ],
       [ 9.211191 ],
       [ 9.250508 ],
       [ 9.250508 ],
       [ 9.250508 ],
       [ 9.171875 ],
       [ 9.014607 ],
       [ 8.935974 ],
       [ 8.876999 ],
       [ 8.876999 ],
       [ 8.85734  ],
       [ 8.798365 ],
       [ 8.719731 ],
       [ 8.719731 ],
       [ 8.680414 ],
       [ 8.582123 ],
       [ 8.464172 ],
       [ 8.385539 ],
       [ 8.36588  ],
       [ 8.444513 ],
       [ 8.739389 ],
       [ 8.994949 ],
       [ 9.171875 ],
       [ 9.270166 ],
       [ 9.368459 ],
       [ 9.407776 ],
       [ 9.407776 ],
       [ 9.427434 ],
       [ 9.368459 ],
       [ 9.250508 ],
       [ 8.994949 ],
       [ 8.680414 ],
       [ 8.326564 ],
       [ 8.090663 ],
       [ 8.012029 ],
       [ 8.012029 ],
       [ 8.031688 ],
       [ 8.071004 ],
       [ 8.012029 ],
       [ 7.9530535],
       [ 7.7761283],
       [ 7.6385193],
       [ 7.4615936],
       [ 7.520569 ],
       [ 8.090663 ],
       [ 8.562464 ],
       [ 8.916315 ],
       [ 9.171875 ],
       [ 9.348801 ],
       [ 9.486409 ],
       [ 9.5650425],
       [ 9.60436  ],
       [ 9.584702 ],
       [ 9.407776 ],
       [ 9.073583 ],
       [ 8.798365 ],
       [ 8.641098 ],
       [ 8.48383  ],
       [ 8.385539 ],
       [ 8.287247 ],
       [ 8.24793  ],
       [ 8.188955 ],
       [ 8.149638 ],
       [ 8.071004 ],
       [ 8.012029 ],
       [ 7.9333954],
       [ 7.87442  ],
       [ 7.9923706],
       [ 8.503489 ],
       [ 8.935974 ],
       [ 9.309484 ],
       [ 9.643677 ],
       [ 9.918894 ],
       [10.095819 ],
       [10.21377  ],
       [10.233429 ],
       [10.154795 ],
       [ 9.899236 ],
       [ 9.525726 ],
       [ 9.23085  ],
       [ 9.0342655],
       [ 8.85734  ],
       [ 8.719731 ],
       [ 8.601781 ],
       [ 8.523148 ],
       [ 8.464172 ],
       [ 8.424855 ],
       [ 8.405197 ],
       [ 8.405197 ],
       [ 8.405197 ],
       [ 8.444513 ],
       [ 8.562464 ],
       [ 8.97529  ],
       [ 9.388117 ],
   [ 9.72231  ]], dtype=float32)

相关问题