for layer in model.layers: #Find the Batch Norm Layers in the Model
if layer.__class__.__name__ == 'BatchNormalization':
layer.build(layer._build_input_shape)
import keras.backend as K
from keras.callbacks import Callback
class ResetBatchNormWeights(Callback):
def __init__(self, reset_epoch):
super(ResetBatchNormWeights, self).__init__()
self.reset_epoch = reset_epoch
def on_epoch_end(self, epoch, logs=None):
if epoch == self.reset_epoch:
for layer in self.model.layers:
if isinstance(layer, keras.layers.BatchNormalization):
K.set_value(layer.moving_mean, K.zeros_like(layer.moving_mean))
K.set_value(layer.moving_variance, K.ones_like(layer.moving_variance))
# Usage example
reset_epoch = 5 # Reset batchnorm weights after 5 epochs
model = ... # Define your CNN model here
model.fit(x_train, y_train, epochs=10, callbacks=[ResetBatchNormWeights(reset_epoch)])
2条答案
按热度按时间5kgi1eie1#
我想我找到了一种方法来做到这一点,但它可能有点不正统,因为它使用了base_layer. py中的Keras私有变量。
字符串
我将保留这个问题,以防有更好(更“Python”)的解决方案。
oxosxuxt2#
不幸的是,没有内置的方法来重置批量归一化权重(moving_mean和moving_variance),同时保留Keras中学习的CNN权重。
build_from_config
方法不适用于此目的,因为它仅基于提供的配置重建层,而不重置任何内部权重。但是,您可以通过创建一个自定义回调来实现这一点,该回调在一定数量的epoch之后重置批处理规范化权重。下面是一个如何实现的示例:
字符串
在本例中,
ResetBatchNormWeights
回调是用一个reset_epoch
参数创建的,该参数指定应重置批处理规范化权重的时期。在on_epoch_end
方法中,回调函数检查当前epoch是否与reset_epoch
匹配,如果匹配,则重置模型中所有批次归一化层的moving_mean和moving_variance。请注意,此实现假设您正在使用TensorFlow后端。如果您使用的是不同的后端,则可能需要相应地修改代码。