我有一组相当复杂的模型,我正在训练,我正在寻找一种方法来保存和加载模型优化器状态。“训练器模型”由其他几个“权重模型”的不同组合组成,其中一些具有共享权重,一些具有取决于训练器的冻结权重,等等。这是一个有点太复杂的例子来分享,但简而言之,停止和开始训练时,我无法使用model.save('model_file.h5')
和keras.models.load_model('model_file.h5')
。
如果训练已经完成,使用model.load_weights('weight_file.h5')
可以很好地测试我的模型,但是如果我尝试使用这个方法继续训练模型,损失甚至不会接近返回到它的上一个位置。我已经读到这是因为优化器状态没有使用这个方法保存,这是有意义的。但是,我需要一个方法来保存和加载我的训练器模型的优化器的状态。看起来好像keras曾经有一个model.optimizer.get_sate()
和model.optimizer.set_sate()
来完成我所追求的,但现在似乎不再是这样了(至少对于Adam优化器来说是这样)。当前的Keras还有其他解决方案吗?
7条答案
按热度按时间eiee3dmh1#
您可以从
load_model
和save_model
函数中提取重要的行。对于保存优化器状态,在
save_model
中:对于加载优化器状态,在
load_model
中:下面是一个示例,将上面的行组合在一起:
1.首先拟合5个时期的模型。
1.现在保存权重和优化器状态。
1.在另一个python会话中重建模型,并加载权重。
1.继续模型培训。
7gyucuyw2#
对于那些不使用
model.compile
而是使用optimizer.apply_gradients
执行自动微分来手动应用梯度的人,我想我有一个解决方案。首先,保存优化器权重:
np.save(path, optimizer.get_weights())
然后,当您准备好重新加载优化器时,向新示例化的优化器显示它将更新的权重大小,方法是在计算梯度的变量大小的Tensor列表上调用
optimizer.apply_gradients
。在设置优化器的权重之后设置模型的权重是非常重要的,因为动量-即使我们给予它的梯度为零,像Adam这样的基于优化器也将更新模型的权重。注意,如果我们在第一次调用
apply_gradients
之前尝试设置权重,则会抛出一个错误,即优化器期望得到一个长度为零的权重列表。ha5z0ras3#
完成Alex Trevithick的回答后,可以避免重新调用
model.set_weights
,只需在应用梯度之前保存变量的状态,然后重新加载。这在从h5文件加载模型时很有用,看起来更干净(imo)。保存/加载功能如下(再次感谢Alex):
ccrfmcuu4#
将Keras升级到2.2.4并使用Pickle为我解决了这个问题。随着Keras发布2.2.3 Keras模型现在可以安全地进行Pickle。
zi8p0yeb5#
任何尝试在分布式设置中使用@Yu-Yang的solution的人都可能会运行以下错误:
或类似的。
要解决此问题,您只需使用以下命令在每个副本上运行模型的优化器权重设置:
出于某种原因,设置模型权重时不需要这样做,但请确保在策略范围内创建(通过此处的调用)并加载模型的权重,否则可能会得到
ValueError: Trying to create optimizer slot variable under the scope for tf.distribute.Strategy (<tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x14ffdce82c50>), which is different from the scope used for the original variable
之类的错误。如果你想要完整的例子,我创建了a colab showcasing this solution。
6mw9ycah6#
下面的代码适合我(Tensorflow 2.5)。
我使用通用语句编码器作为模型,同时使用Adam优化器。
基本上我做的是:我使用了一个虚拟输入来正确地设置优化器。
之后我设定了重量。
保存优化程序的权重
加载优化程序
q5iwbnjs7#
从版本2.11开始,
optimizer.get_weights()
不再可用。你可以最终切换到tf.optimizers.legacy类,但不推荐。相反,类tf.train.Checkpoint是专门为保存模型和优化器权重而设计的:
最后,类tf.train.CheckpointManager管理多个检查点版本,使其变得非常简单: