我们发现tf.keras.layers.BatchNormalization的实现并不符合其数学模型,问题的原因可能来自于它的epsilon或variance参数,错误的发生具体分为四个步骤:
(1)初始化BN算子(即source_model),随机输入输入(即data),并得到输出(即source_result);
(2)随机产生一个扰动(即delta),将source_model的方差加到delta上,并从epsilon中减去delta,得到一个新的BN算子(即follow_model);
(3)向follow-model输入数据,得到follow-result;
(4)计算source_result和follow_result之间的距离,理论上应该很小甚至为0,实际上可以得到大于1的结果
# from tensorflow.keras.layers import BatchNormalization, Input
# from tensorflow.keras.models import Model, clone_model
from tensorflow._api.v1.keras.layers import BatchNormalization, Input
from tensorflow._api.v1.keras.models import Model, clone_model
import os
import re
import numpy as np
def SourceModel(shape):
x = Input(shape=shape[1:])
y = BatchNormalization(axis=-1)(x)
return Model(x, y)
def FollowModel_1(source_model):
follow_model = clone_model(source_model)
# read weights
weights = source_model.get_weights()
weights_names = [weight.name for layer in source_model.layers for weight in layer.weights]
variance_idx = FindWeightsIdx("variance", weights_names)
# mutation operator
# delta = np.random.uniform(-1e-3, 1e-3, 1)[0]
follow_model.layers[1].epsilon += delta # mutation epsilon
weights[variance_idx] -= delta
follow_model.set_weights(weights)
return follow_model
def FindWeightsIdx(name, weights_names):
# find layer index by name
for idx, names in enumerate(weights_names):
if re.search(name, names):
return idx
return -1
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
shape = (10, 32, 32, 3)
data = np.random.uniform(-1, 1, shape)
delta = -1
source_model = SourceModel(shape)
follow_model = FollowModel_1(source_model)
source_result = source_model.predict(data)
follow_result = follow_model.predict(data)
dis = np.sum(abs(source_result-follow_result))
print("delta:", delta, "; dis:", dis)
delta再大,dis也应该小一些,但实际上并没有,这说明张流的batch-norm算子可能存在bug,tf1.x和tf2.x都有这个问题
delta: -1 ; dis: 4497.482
1条答案
按热度按时间polhcujo1#
我不认为这是一个错误,而是一个未记录的行为。
我注意到,对于您的代码,我没有得到任何
delta > 0
的差异,或者实际上任何delta > -0.001
--默认的epsilon是0.001
,所以使用较大的delta意味着我们仍然有epsilon > 0
。任何较大的负delta(特别是在您的示例中为-1)将导致epsilon < 0
。epsilon < 0
有问题吗?是的,因为在除以方差时需要防止除以0。方差总是〉0,所以在这里减去某个值可能会导致除以0。更相关的问题是,epsilon与方差相加,然后取平方根以获得标准差,如果variance + epsilon
〈0,则会崩溃。这对于小方差和epsilon < 0
可能发生。我的直觉是,在代码的某个地方,他们做了一些类似
abs(epsilon)
的事情来防止这样的问题,但是,我在层实现中找不到任何东西,在层使用的op中也找不到。然而,BN默认使用一个“融合”的实现,这样更快。这是这里。我们在这里看到这些行:
因此,epsilon确实总是正的。我唯一不明白的是,你可以将
fused=False
传递给BN构造函数,而这应该使用更基本的实现,在那里我找不到修改epsilon的任何东西。但是当我测试它时,问题仍然存在。不确定问题是什么...tl;dr:你有
epsilon < 0
。不要这样做,这很糟糕。