当用GPU训练时,我看到损失和val-loss的波动相当大,而且很快就会损失NaN。如果我降低学习率很多,我有时可以防止它损失NaN。但是模型性能相当差。
我的CPU和GPU训练之间的初始化值与我将模型保存到磁盘并从该文件开始训练CPU和GPU相同。
我试过剪辑值,标准值
如果我使用CPU进行训练,我可以使用高学习率,并且我看到损失和val-loss的稳定下降。即使在与GPU相同的学习率下,CPU也可以看到相同数量的epoch的较低损失和val-loss。
我用3个LSTM分支(32个参数输入)训练一个模型,每个LSTM分支有一层512个单元,另一层有128个单元,连接起来,然后是2个265个单元的密集层。我在LSTM层上使用了少量的l2 reg(0.00001),在密集层上使用了少量的dropout(0.05)。我已经尝试了使用和不使用dropout和L2 reg。
我用的是标准均方损失。
模型架构与我的问题没有任何关系,因为我已经尝试了很多不同的层和参数。批量大小没有任何影响。我尝试了非常低的批量大小。我已经尝试了:各种调节技术。剪辑值Norm value较低的学习率GPU burn test(没有错误)很多不同的模型架构。甚至有些没有LSTM层。
我的GPU是一个GTX 1080。我已经尝试了刻录测试。显示0错误。
我使用Tensorflow 2.13 cuDNN 8.6 Nvidia驱动程序535.86.05
这是我的模型:
num_properties = 12
dropout_rate = 0.05
l2_reg_value = 0.00001
input_branch1 = Input(shape=(32, num_properties))
input_branch2 = Input(shape=(32, num_properties))
input_branch3 = Input(shape=(32, num_properties))
lstm_branch1 = LSTM(512, return_sequences=True, kernel_regularizer=l2(l2_reg_value))(input_branch1)
lstm_branch1 = LSTM(128, kernel_regularizer=l2(l2_reg_value))(lstm_branch1)
lstm_branch2 = LSTM(512, return_sequences=True, kernel_regularizer=l2(l2_reg_value))(input_branch2)
lstm_branch2 = LSTM(128, kernel_regularizer=l2(l2_reg_value))(lstm_branch2)
lstm_branch3 = LSTM(512, return_sequences=True, kernel_regularizer=l2(l2_reg_value))(input_branch3)
lstm_branch3 = LSTM(128, kernel_regularizer=l2(l2_reg_value))(lstm_branch3)
concatenated = concatenate([lstm_branch1, lstm_branch2, lstm_branch3])
dense = Dense(256, activation='selu')(concatenated)
dense = Dropout(dropout_rate)(dense)
dense = Dense(256, activation='selu')(dense)
dense = Dropout(dropout_rate)(dense)
output = Dense(1, activation='linear')(dense)
model = Model(inputs=[input_branch1, input_branch2, input_branch3], outputs=output)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001, clipnorm=1), loss="mae")
有什么建议吗?
我刚刚尝试在一台带有A10 GPU的Lambda GPU云机器上使用完全相同的模型和训练数据进行训练。没有问题。
所以问题出在我的PC或TensorFlow等版本上。我试着重新安装Ubuntu,并按照TensorFlow的安装说明操作。同样的问题。
接下来我将尝试另一个GPU。
1条答案
按热度按时间bqf10yzr1#
GTX 1080是问题,即使它没有接缝故障以任何其他方式。
用RTX 3090代替它解决了这个问题。