第一步: dump trainable variables and moving mean/var
# Dump all trainable variables plus the BatchNorm moving statistics into a
# single {variable name: numpy value} dict and pickle it.
# Expects `tf`, `pickle` and a live session `sess` to already be in scope.
# --------------------------------------------------------------
weights = {}
for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    print("var name: {}; shape: {}".format(var.name, var.shape))
    weights[var.name] = sess.run(var)
# The moving mean/variance are picked up by name from the full variable
# collection. GLOBAL_VARIABLES is the current name of the deprecated
# GraphKeys.VARIABLES collection.
for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    if "moving" in var.name:
        print("var name: {}; shape: {}".format(var.name, var.shape))
        weights[var.name] = sess.run(var)
# pickle needs a binary file handle; `with` guarantees the data is flushed
# and the handle closed before the next step reads the file.
with open('./dumps/tf_weights.data', 'wb') as dump_file:
    pickle.dump(weights, dump_file)
第二步: fuse weights of conv and bn
import pickle
import numpy as np

# Epsilon added to the variance inside the BatchNorm layers. It has to match
# the value the graph was built with, otherwise the fused layer will not
# reproduce the original outputs.
DEFAULT_BN_EPSILON = 1e-5
# Layers whose name contains 'encoder' use a larger epsilon.
ENCODER_BN_EPSILON = 1e-3


def fuse_conv_bn(weights):
    """Fold BatchNorm parameters into the preceding conv weights, in place.

    For every conv kernel (name contains 'conv', 'Conv' or 'shortcut' and
    ends in a 'weights' / 'W:0' variable) that has a sibling
    ``BatchNorm/gamma:0`` entry, the kernel and bias are rewritten as::

        scale  = gamma / sqrt(moving_variance + epsilon)
        weight = weight * scale          # broadcast over the output channel
        bias   = (bias - moving_mean) * scale + beta

    A missing ``biases:0`` entry is treated as 0 and the fused bias is stored
    under that name. Conv kernels without BatchNorm are left untouched.

    Args:
        weights: dict mapping TF variable names to numpy arrays. Conv kernels
            are expected in [filter, filter, in_channel, out_channel] layout.

    Returns:
        The same dict object, updated in place.
    """
    # Iterate over a snapshot of the keys: fused biases are inserted into the
    # dict inside the loop, which is an error when iterating it directly.
    for key in list(weights):
        if not ('conv' in key or 'Conv' in key or 'shortcut' in key):
            continue
        # Strip only the trailing variable name so that scope names that
        # happen to contain 'W' or 'weights' do not break the split.
        if 'W:0' in key:
            base_name = key.rsplit('W:0', 1)[0]
        elif 'weights' in key:
            base_name = key.rsplit('weights', 1)[0]
        else:
            continue
        print(key)

        # Choose epsilon per layer; it must not leak into later iterations.
        epsilon = ENCODER_BN_EPSILON if 'encoder' in key else DEFAULT_BN_EPSILON

        gamma_name = base_name + "BatchNorm/gamma:0"
        if gamma_name not in weights:
            continue  # conv without BatchNorm: nothing to fuse
        beta_name = base_name + "BatchNorm/beta:0"
        mean_name = base_name + "BatchNorm/moving_mean:0"
        variance_name = base_name + "BatchNorm/moving_variance:0"
        bias_name = base_name + "biases:0"

        gamma = weights[gamma_name]
        beta = weights[beta_name]
        moving_mean = weights[mean_name]
        moving_var = weights[variance_name]
        bias = weights.get(bias_name, 0)

        scale = gamma / np.sqrt(moving_var + epsilon)
        # [filter, filter, in_channel, out_channel]: scale each output channel
        weights[key] = weights[key] * scale.reshape([1, 1, 1, -1])
        weights[bias_name] = (bias - moving_mean) * scale + beta
    return weights


if __name__ == "__main__":
    # Paths match what the dump step writes and what the load step reads.
    # NOTE: pickle can execute arbitrary code; only load dumps you created.
    with open("./dumps/tf_weights.data", 'rb') as src:
        weights = pickle.load(src)
    fuse_conv_bn(weights)
    with open("./dumps/tf_bn_fusion.data", 'wb') as dst:
        pickle.dump(weights, dst)
    print("-----------dumped bn fusion weights into [./dumps/tf_bn_fusion.data]---------------")
第三步: load fused weights
- remove bn layer in model.
- add bias after conv
- load weights:
# Restore the fused weights into the graph. Every dumped entry whose name
# matches a variable in the current graph is assigned; the rest are skipped.
# Expects `tf`, `pickle` and a live session `sess` to already be in scope.
# NOTE: pickle can execute arbitrary code; only load dumps you created.
with open('./dumps/tf_bn_fusion.data', 'rb') as dump_file:
    weights = pickle.load(dump_file)
# Index the graph's variables by name so the assignment targets the variable
# object itself. GLOBAL_VARIABLES is the current name of the deprecated
# GraphKeys.VARIABLES collection.
graph_vars = {var.name: var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)}
assign_ops = []
for key in weights:
    if key in graph_vars:
        print("load: {}".format(key))
        assign_ops.append(tf.assign(graph_vars[key], weights[key], validate_shape=True))
    else:
        print("skip: {}".format(key))
# Run all assignments in one session call instead of one call per variable.
if assign_ops:
    sess.run(assign_ops)
print('-----------------loaded fused weights-----------------')
暂无答案!
目前还没有任何答案,快来回答吧!