在Keras中,如何修改每个批次的损失(使用额外的代码,应该在训练期间运行)

lf5gs5x2  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(117)

使用这个自定义回调,我可以1)看到训练过程中的损失2)访问正在训练的模型

class ChangeBatchLoss(tf.keras.callbacks.Callback):
    def on_train_batch_begin(self, batch, logs=None):
        if 't_loss' in logs:
            print(logs, file=sys.stderr)
            print(self.model, file=sys.stderr)

我的问题是:在训练过程中是否可以修改损失本身?(我想执行一些额外的计算,并添加/减去损失(在我的代码中,“损失”对应于logs[“t_loss”]显示的值。
你知道吗?
谢谢

xoefb8l8

xoefb8l81#

1.为相关模型创建自定义模型
1.覆盖function _make_train_function并修改self.metrics_tensors或self.total_loss

相关问题