这个问题类似于Tensorflow Keras modify model variable from callback。我无法让那里的解决方案工作(也许自从解决方案发布以来,TensorFlow 2. x已经发生了变化)。
下面是演示代码。我道歉,如果有一个错字。
我想使用回调来更新一个影响层输出的不可训练变量(weighted_add_layer.weight
)。
我尝试过许多变体,例如将tf.keras.backend.set_value(weighted_add_layer.weight, value)
放在update function
中。
在所有情况下,编译模型后,fit
使用编译时weighted_add_layer.weight
的值,以后不再更新该值。
class WeightedAddLayer(tf.keras.layers.Layer):
def __init__(self, weight=0.00, *args, **kwargs):
super(WeightedAddLayer, self).__init__(*args, **kwargs)
self.weight = tf.Variable(0., trainable=False)
def add(self, inputA, inputB):
return (self.weight * inputA + self.weight * inputB)
def update(self, weight):
tf.keras.backend.set_value(self.weight, weight)
input_A = tfkl.Input(
shape=(32),
batch_size=32,
)
input_B = tfkl.Input(
shape=(32),
batch_size=32,
)
weighted_add_layer = WeightedAddLayer()
output = weighted_add_layer.add(input_A, input_B)
model = tfk.Model(
inputs=[input_A, input_B],
outputs=[output],
)
model.compile(
optimizer='adam', loss=losses.MeanSquaredError()
)
# Custom callback function
def update_fun(epoch, steps=50):
weighted_add_layer.update(
tf.clip_by_value(
epoch / steps,
clip_value_min=tf.constant(0.0),
clip_value_max=tf.constant(1.0),)
)
# Custom callback
update_callback = tfk.callbacks.LambdaCallback(
on_epoch_begin=lambda epoch, logs: update_fun(epoch)
)
# train model
history = model.fit(
x=train_data,
epochs=EPOCHS,
validation_data=valid_data,
callbacks=[update_callback],
)
有什么建议吗?非常感谢!
2条答案
按热度按时间xfyts7mz1#
1.这可能是TensorFlow 2.11.0或我的安装或我遗漏的其他东西的问题,但lambda回调的使用非常不稳定,我的代码库和错误检查不断,并且没有做我想做的事情。它还导致了奇怪的行为,使它看起来像是有内存泄漏。完整模型的代码非常复杂,我不“我没有时间进行调试,所以我在分享这些信息的同时提出了一个很大的FWIW警告。
1.在定制Keras模型中,是否有办法在www.example.com()和model.evaluate()的正向传递过程中使层的行为不同?中的代码model.fit可以正常工作。
a.你必须让tf.variable位于一个层内,并且不可训练。我无法让这种方法在层外使用tf.variable。这不是什么大问题,因为你总是可以定义一个微不足道的层,只缩放输入或做一些简单的计算,然后使用该层来完成任务。我发现tf.层外的变量被编译器优化掉了,所以没有办法在编译后更新。
使用Assign作为更新设备效果很好。我尝试了其他方法,但最终使用了Assign。
这是一个回调子类,与演示代码一致。注意,当使用该类时,你必须在调用
fit
时示例化该类的一个示例。你不能传递回调的名称。还要注意,这不是我真实的的代码,而是我写的与上面的演示代码一致的东西。它没有经过测试,可能有错误/错别字。myzjeezk2#
不幸的是,基于this question,这个问题和其中的相关链接,我认为不可能冻结 *
model.compile()
之后的 * 层。在您的情况下,您必须保存,冻结,然后重新编译模型。