如何在自定义训练循环中使用tf.keras.layers.BatchNormalization()?

fbcarpbf  于 2023-08-06  发布在  其他
关注(0)|答案(2)|浏览(119)

过了一段时间,我又回到了TensorFlow,看起来情况完全改变了。
然而,以前我在训练循环中使用tf.contrib....batch_normalization和以下代码:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(cnn.loss, global_step=global_step)

字符串
但似乎,contrib无处可寻,tf.keras.layers.BatchNormalization的工作方式也不一样。另外,我在他们的documentation中找不到任何培训说明。
因此,任何帮助的信息是赞赏。

kxxlusnw

kxxlusnw1#

我开始使用PyTorch。它解决了问题。

ajsxfq5m

ajsxfq5m2#

由于批量归一化在训练和推断过程中的行为不同,因此传入call()training变量需要被馈送到BatchNormalization层。

import tensorflow as tf

def call(self, inputs, training=False):
    x = tf.keras.layers.BatchNormalization()(inputs, training=training)
    return x

字符串

相关问题