python: How to prevent a TensorFlow model from reporting inf as the loss when training with more than 17 data pairs

Posted by lmyy7pcs on 2023-03-21 in Python

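I am trying to train a model that predicts x^2 for an input x. I am fairly new to AI and have tried a few similar things. With xs and ys of length less than 18 everything works fine, but as soon as the length is >= 18 the loss first becomes very large and then, after 3 or 4 epochs, reaches infinity. As a result the prediction is invalid as well.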

import tensorflow as tf
import numpy as np
from tensorflow import keras

def quadratVonX(y_new):
    xs = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], dtype=float)
    ys = np.array([0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625, 676, 729, 784, 841, 900], dtype=float)
    model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])

    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(xs, ys, epochs=50)
    return model.predict(y_new)[0]

prediction = quadratVonX([15])
print(prediction)

Sample output from one of the epochs:

Epoch 50/50
1/1 [==============================] - 0s 1ms/step - loss: inf
1/1 [==============================] - 0s 58ms/step
[-inf]

I expected a valid result rather than a negative/non-existent one, as I get when xs and ys have fewer than 18 elements.

db2dz4w8 #1

I suspect the learning rate is too high. The default for SGD is 0.01. I tried your example with a learning rate of 0.001 and it converged.
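One way to see why 0.01 fails here while 0.001 works: with 31 samples and the Keras default batch size of 32, every epoch is a single full-batch gradient step (the log shows 1/1). For a linear model w*x + b trained on mean squared error, plain gradient descent is only stable when the learning rate is below 2 divided by the largest eigenvalue of the loss Hessian, and that eigenvalue grows with the size of the x values. The following is a pure-Python sketch of that bound, not TensorFlow code. The helper names `max_stable_learning_rate` and `run_gradient_descent` are hypothetical, and the weights start at zero instead of using Keras's random initialisation.

```python
import math

def max_stable_learning_rate(xs):
    """Step-size bound for full-batch gradient descent on the MSE of y = w*x + b."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_x2 = sum(x * x for x in xs) / n
    # Hessian of the MSE w.r.t. (w, b) is 2 * [[mean_x2, mean_x], [mean_x, 1]].
    # Its largest eigenvalue, in closed form for a symmetric 2x2 matrix:
    largest_eig = (mean_x2 + 1) + math.sqrt((mean_x2 - 1) ** 2 + 4 * mean_x ** 2)
    return 2.0 / largest_eig

def run_gradient_descent(xs, ys, learning_rate, epochs=50):
    """Full-batch gradient descent from w = b = 0 (an assumption); returns the final MSE."""
    n = len(xs)
    w, b = 0.0, 0.0
    for _ in range(epochs):
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2.0 * sum(errors) / n
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n

xs = [float(i) for i in range(31)]
ys = [x * x for x in xs]

bound = max_stable_learning_rate(xs)              # about 0.0033 for x = 0..30
loss_default = run_gradient_descent(xs, ys, 0.01)   # above the bound: blows up
loss_small = run_gradient_descent(xs, ys, 0.001)    # below the bound: stays finite
print(bound, loss_default, loss_small)
```

Python floats are 64-bit, so the diverging run prints an astronomically large number rather than inf. Keras computes in float32 by default, where the same growth overflows to inf after a few epochs. The bound is above 0.01 for x = 0..16 (17 points) and below it for x = 0..18 (19 points), which is roughly where the question reports the problem starting.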

import tensorflow as tf
import numpy as np

def quadratVonX(y_new):
    xs = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], dtype=float)
    ys = np.array([0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625, 676, 729, 784, 841, 900], dtype=float)
    model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=(1,))])

    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), loss='mean_squared_error')
    model.summary()
    model.fit(xs, ys, epochs=50)
    return model.predict(y_new)

prediction = quadratVonX([15])
print(prediction)

Result:

Epoch 50/50
1/1 [==============================] - 0s 948us/step - loss: 10422.5850
1/1 [==============================] - 0s 35ms/step
[[342.24048]]
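Note that this log shows the training is now stable, not that the model fits x^2 well: the prediction for x = 15 is about 342, while the true value is 225. A single `Dense(units=1)` layer without an activation is just a straight line w*x + b, so its error on a parabola has a floor that no learning rate can remove. The sketch below computes the closed-form least-squares line for the same data in pure Python. The helper name `best_line` is hypothetical and only illustrates that limit.

```python
def best_line(xs, ys):
    """Closed-form least-squares line; returns (slope, intercept, mse)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / n
    var_x = sum((x - mean_x) ** 2 for x in xs) / n
    slope = cov_xy / var_x
    intercept = mean_y - slope * mean_x
    mse = sum((slope * x + intercept - y) ** 2 for x, y in zip(xs, ys)) / n
    return slope, intercept, mse

xs = [float(i) for i in range(31)]
ys = [x * x for x in xs]

slope, intercept, mse = best_line(xs, ys)
print(slope, intercept, mse)          # slope 30, intercept -145, MSE 5104
print(slope * 15 + intercept)         # best linear prediction for x = 15: 305
```

So even a fully converged linear model would report a loss of about 5104 and predict 305 for x = 15. The 10422 in the log means 50 epochs at a learning rate of 0.001 have not reached that optimum yet. Getting closer to x^2 would need a nonlinear model, for example hidden layers with an activation such as ReLU, which goes beyond what this answer covers.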
