如果您使用的是大型数据集,并且没有易于序列化的数据(如高秩numpy数组)来写入tfrecord,那么Train_on_batch的性能也会比fit和fit生成器有所提高。 在这种情况下,您可以将数组保存为numpy文件并加载其中较小的子集(traina.npy,trainb.npy等),当整个集合不适合内存时,您可以使用tf.data.Dataset.from_tensor_slices,然后使用train_on_batch处理您的子数据集,然后加载另一个数据集并再次调用train on batch,等等。现在你已经训练了你的整个数据集,并且可以精确地控制你的数据集训练你的模型的多少和什么。然后你可以定义你自己的时期,批量大小,等等,用简单的循环和函数从你的数据集中抓取。
@nbro answer确实有帮助,只是为了增加一些场景,假设你正在训练一些seq to seq模型或者一个带有一个或多个编码器的大型网络。我们可以使用train_on_batch创建自定义训练循环,并使用我们的一部分数据直接在编码器上验证,而不使用回调。为复杂的验证过程编写回调可能很困难。有几种情况下我们希望在批处理上训练。 问候你,卡希克
5条答案
按热度按时间q3qa4bjr1#
对于这个问题,第一作者给出了一个简单的答案:
使用
fit_generator
,您还可以使用验证数据的生成器。一般来说,我建议使用fit_generator
,但使用train_on_batch
也可以。这些方法只是为了在不同的用例中使用方便,没有“正确”的方法。train_on_batch
允许您根据您提供的样本集合明确更新权重,而不考虑任何固定的批处理大小。您可以在需要时使用此选项:您可以使用这种方法在传统训练集的多个批处理上维护自己的迭代,但允许fit
或fit_generator
为您迭代批处理可能更简单。使用
train_on_batch
的一种情况是,在一批新样本上更新预先训练好的模型。假设您已经训练并部署了一个模型,稍后您收到了一组以前从未使用过的新训练样本。您可以使用train_on_batch
仅在这些样本上直接更新现有模型。其他方法也可以做到这一点。但是对于这种情况使用train_on_batch
是相当明确的。除了像这样的特殊情况(出于某种教学原因,您需要在不同的训练批处理之间维护自己的游标,或者需要在特殊批处理上进行某种类型的半在线训练更新),最好始终使用
fit
(用于适合内存的数据)或fit_generator
(用于将数据流批处理作为生成器)。cngwdvgl2#
train_on_batch()
使您能够更好地控制LSTM的状态,例如,在使用有状态LSTM并需要控制对model.reset_states()
的调用时。您可能有多个系列的数据,并需要在每个系列之后重置状态,这可以使用train_on_batch()
来完成。但是如果你使用.fit()
,那么网络将在所有的数据序列上训练,而不需要重置状态。没有对与错,这取决于你使用的是什么数据,以及您希望网络如何运行。mctunoxg3#
如果您使用的是大型数据集,并且没有易于序列化的数据(如高秩numpy数组)来写入tfrecord,那么Train_on_batch的性能也会比fit和fit生成器有所提高。
在这种情况下,您可以将数组保存为numpy文件并加载其中较小的子集(traina.npy,trainb.npy等),当整个集合不适合内存时,您可以使用tf.data.Dataset.from_tensor_slices,然后使用train_on_batch处理您的子数据集,然后加载另一个数据集并再次调用train on batch,等等。现在你已经训练了你的整个数据集,并且可以精确地控制你的数据集训练你的模型的多少和什么。然后你可以定义你自己的时期,批量大小,等等,用简单的循环和函数从你的数据集中抓取。
kknvjkwl4#
@nbro answer确实有帮助,只是为了增加一些场景,假设你正在训练一些seq to seq模型或者一个带有一个或多个编码器的大型网络。我们可以使用train_on_batch创建自定义训练循环,并使用我们的一部分数据直接在编码器上验证,而不使用回调。为复杂的验证过程编写回调可能很困难。有几种情况下我们希望在批处理上训练。
问候你,卡希克
rwqw0loc5#
来自Keras -模型培训API:
当我们一次使用一批训练数据集更新鉴别器和生成器时,我们可以在GAN中使用它。我在Jason Brownlee的教程(How to Develop a 1D Generative Adversarial Network From Scratch in Keras)中看到使用train_on_batch。
快速搜索提示:键入Control+F,然后在搜索框中键入要搜索的术语(例如train_on_batch)。