Tensorflow:如何训练一段时间而不是几个时期?

zpf6vheq  于 2023-03-03  发布在  其他
关注(0)|答案(1)|浏览(119)

既往研究:
Most relevant tensorflow article
How can I calculate the time spent for overall training a model in Tensorflow (for all epochs)?
Show Estimated remaining time to train a model Tensorflow with large epochs
代码:

y = to_categorical(self.ydata, num_classes=self.vocab_size)
model = Sequential()
model.add(Embedding(self.vocab_size, 10, input_length=1))
model.add(LSTM(1000, return_sequences=True))
model.add(LSTM(1000))
model.add(Dense(1000, activation="relu"))
model.add(Dense(self.vocab_size, activation="softmax"))
keras.utils.plot_model(model, show_layer_names=True)
checkpoint = ModelCheckpoint(modelFilePath, monitor='loss', verbose=1,save_best_only=True, mode='auto')
reduce = ReduceLROnPlateau(monitor='loss', factor=0.2,patience=3, min_lr=0.0001, verbose=1)
tensorboard_Visualization = TensorBoard(log_dir=logdirPath)
model.compile(loss="categorical_crossentropy", optimizer=Adam(lr=0.001))
history = model.fit(self.Xdata, y, epochs=epochs, batch_size=64, callbacks=[checkpoint, reduce, tensorboard_Visualization]).history

灵感来源:

  1. https://www.analyticsvidhya.com/blog/2021/08/predict-the-next-word-of-your-text-using-long-short-term-memory-lstm/
  2. https://towardsdatascience.com/building-a-next-word-predictor-in-tensorflow-e7e681d4f03f
    这段代码需要一个单词“问题”和“答案”的列表来训练。如果你在阅读这段代码之前就猜到了模型的目标,那么背景知识会给你留下深刻的印象。无论如何,这段代码是有效的。我只想在这一点上增强它。
    如何在设定的时间内训练模型?一个时期所需的时间因我输入给此AI的文本而异。它变化很大,通常在10秒到4分钟左右。我可以使用它来根据时间估算时期,但如果存在其他方法,我希望从TensorFlow的资源中获得更具体的想法。
yeotifhr

yeotifhr1#

如果必须按照您表述问题的方式定义超时,那么请看一下this answer
然而,准确性会根据你输入的文本发生很大的变化,从极差到过适,所以你最终会花更多的时间来验证,一个更适合你的问题的是定制的EarlyStopping

from tensorflow.keras.callbacks import EarlyStopping

custom_early_stopping = EarlyStopping(
    monitor='val_accuracy', 
    patience=8, 
    min_delta=0.001, 
    mode='max'
)

history = model.fit(
    # rest of parameters
    callbacks=[custom_early_stopping ] # and rest of callbacks
)

在本例中,您将验证精度设置为性能监视器,以确定何时停止训练。我不认为您会使用低精度的训练模型,或者仅仅因为设置了时间而在达到它之后继续训练。
patience=8表示训练在8个时期没有改善时立即终止。min_delta=0.001表示验证精度必须改善至少0.001才能算作改善。mode='max'表示当监视的量停止增加时将停止。

相关问题