保存tensorflow 联邦模型

yeotifhr  于 2023-03-03  发布在  其他
关注(0)|答案(1)|浏览(139)

我有一个tensorflow 联邦模型,如下所示:

state = iterative_process.initialize()

其中state是封装模型的服务器的状态。
打印它,我们有:

ServerState(model=ModelWeights(trainable=[array([[ 0.01307054,  0.05205479,  0.16566667, ...,
       [0.],
       [0.],
       [0.],
       [0.]], dtype=float32), array([0.], dtype=float32)]), ..., ('zeroed_count_agg', ())]), model_broadcast_state=())

即服务器状态。
我可以使用以下命令访问模型的参数:

state.model.trainable

这是一个列表。
我想做的是保存这个,并在未来重新加载它。
理想情况下,我希望使用此模型更新流程的未来状态(每个联邦迭代返回一个新状态)。
有什么想法吗?
我还发现了this SO线程,但那里的一切似乎都被弃用了。

nbewdwxp

nbewdwxp1#

刚刚找到了一个解决办法。创建一个相同的keras模型。

seq = tf.keras.models.Sequential([
     ...
  ])

将状态的权重设置到此模型,然后保存keras模型:

seq.set_weights(state.model.trainable)
seq.save('...')

相关问题