Paddle: fusing Conv and BatchNorm (BN) layers in an end-to-end OCR model

Posted by 0md85ypi on 2021-11-29 in Java
Follows (0) | Answers (0) | Views (173)

Step 1: dump the trainable variables and the BatchNorm moving mean/variance


# Step 1: dump every trainable variable plus the BatchNorm moving statistics
# (moving_mean / moving_variance) of the live TF1 session to a pickle file.
# --------------------------------------------------------------
# NOTE(review): `tf` and `sess` are not defined here; this snippet assumes it
# runs inside a script that has already built the graph and opened `sess`.
import os
import pickle

dump_path = './dumps/tf_weights.data'

# Trainable variables first, then the non-trainable BN moving statistics,
# picked out by name. GLOBAL_VARIABLES is the non-deprecated name of the
# collection that the old GraphKeys.VARIABLES alias refers to.
dump_vars = list(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
dump_vars += [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
              if "moving" in var.name]
for var in dump_vars:
    print("var name: {}; shape: {}".format(var.name, var.shape))

# One sess.run fetches every value in a single call instead of one call per variable.
weights = dict(zip([var.name for var in dump_vars], sess.run(dump_vars)))

# Make sure ./dumps exists; pickle needs a binary-mode file ('wb', not 'w').
dump_dir = os.path.dirname(dump_path)
if dump_dir and not os.path.isdir(dump_dir):
    os.makedirs(dump_dir)
with open(dump_path, 'wb') as f:
    pickle.dump(weights, f)

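Step 2: fuse the conv and BN weights. For each conv layer followed by BN, let $s = \gamma / \sqrt{\sigma^2 + \epsilon}$ (one value per output channel, where $\mu$ and $\sigma^2$ are the BN moving mean and moving variance). The fused kernel and bias are $W' = W \cdot s$ and $b' = (b - \mu)\, s + \beta$, where $b = 0$ if the conv had no bias of its own.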

import pickle
import numpy as np

# Step 2: fold each BatchNorm into the conv (or shortcut) kernel in front of it.
# NOTE(review): pickle.load can execute arbitrary code; only load dumps you
# produced yourself in step 1.


def _is_conv_kernel(key):
    """Return True if *key* looks like a conv/shortcut kernel that may have a BN to fold."""
    return (('conv' in key or 'Conv' in key or 'shortcut' in key)
            and ('weights' in key or 'W:0' in key))


def _scope_prefix(key):
    """Return the part of a kernel name before 'W:0' / 'weights' (e.g. 'conv1/')."""
    # Split on the LAST occurrence only, so a 'W' elsewhere in the scope name
    # (e.g. 'Conv_W/W:0') cannot break the two-way unpacking the way
    # key.split('W') did. 'W:0' takes precedence over 'weights', as before.
    if 'W:0' in key:
        return key.rsplit('W:0', 1)[0]
    return key.rsplit('weights', 1)[0]


def fuse_conv_bn(weights, epsilon=1e-5, encoder_epsilon=1e-3, verbose=False):
    """Fold BatchNorm parameters into the preceding conv kernels and biases.

    For every conv/shortcut kernel ``<scope>weights...`` or ``<scope>W:0`` that
    has a matching ``<scope>BatchNorm/gamma:0`` entry, with the per-output-channel
    scale ``s = gamma / sqrt(moving_variance + eps)``::

        kernel' = kernel * s            (broadcast over the last axis)
        bias'   = (bias - moving_mean) * s + beta

    ``bias`` is ``<scope>biases:0`` when present, else 0; the fused bias is
    stored under ``<scope>biases:0``.

    Args:
        weights: dict mapping TF variable names to numpy arrays (the step-1 dump).
        epsilon: BN epsilon for ordinary layers.
        encoder_epsilon: BN epsilon for keys containing 'encoder'.
        verbose: print each updated kernel and bias name.

    Returns:
        A new dict with fused kernels and biases. The input dict is not
        modified, and the BatchNorm entries are carried over unchanged.
    """
    fused = dict(weights)
    # Iterate over a snapshot of the keys: new '<scope>biases:0' entries are
    # added to `fused`, and all reads come from the unmodified input.
    for key in list(weights):
        if not _is_conv_kernel(key):
            continue
        base_name = _scope_prefix(key)

        # Choose epsilon per key. The original reassigned a global, so after the
        # first encoder key every later layer silently used 1e-3 as well.
        if 'encoder' in key:
            eps = encoder_epsilon
            print(key)
        else:
            eps = epsilon

        gamma_name = base_name + "BatchNorm/gamma:0"
        if gamma_name not in weights:
            continue  # no BatchNorm follows this kernel
        beta = weights[base_name + "BatchNorm/beta:0"]
        mean = weights[base_name + "BatchNorm/moving_mean:0"]
        var = weights[base_name + "BatchNorm/moving_variance:0"]
        bias_name = base_name + "biases:0"
        bias = weights.get(bias_name, 0)

        invs = weights[gamma_name] / np.sqrt(var + eps)
        # The out_channel axis is last ([filter, filter, in_channel, out_channel]
        # for conv kernels), so plain broadcasting scales it. Unlike reshaping
        # to [1, 1, 1, C], this also preserves the shape of non-4-D kernels.
        fused[key] = weights[key] * invs
        fused[bias_name] = (bias - mean) * invs + beta
        if verbose:
            print("update weight : [{}]".format(key))
            print("|--Add bias : [{}]".format(bias_name))
    return fused


def main(src="./tf_weights.data", dst="./tf_bn_fusion.data"):
    """Load the step-1 dump *src*, fuse BN into conv, and pickle the result to *dst*.

    NOTE(review): step 1 writes ./dumps/tf_weights.data and step 3 reads
    ./dumps/tf_bn_fusion.data, so these defaults presumably assume the script
    runs from inside ./dumps/. Confirm this, or pass the paths explicitly.
    """
    # Pickle data must be read and written in binary mode ('rb' / 'wb').
    with open(src, 'rb') as f:
        weights = pickle.load(f)
    fused = fuse_conv_bn(weights)
    with open(dst, 'wb') as f:
        pickle.dump(fused, f)
    print("-----------dumped bn fusion weights into [{}]---------------".format(dst))


if __name__ == "__main__":
    main()

Step 3: load the fused weights

  1. Remove the BN layers from the model.
  2. Add a bias after each conv layer.
  3. Load the fused weights:
# Step 3: load the fused weights from step 2 into the BN-free graph (BN layers
# removed, a bias added after each conv). Names in the pickle that have no
# matching graph variable (e.g. the now-unused BatchNorm/* entries) are skipped.
# NOTE(review): assumes `tf` and `sess` come from the surrounding script.
import pickle

# pickle.load can execute arbitrary code: only load files you produced yourself.
# Pickle data must be read in binary mode ('rb', not 'r').
with open('./dumps/tf_bn_fusion.data', 'rb') as f:
    weights = pickle.load(f)

# Map name -> Variable object. Assigning through the Variable, rather than the
# tensor returned by graph.get_tensor_by_name, also works for resource variables.
graph_vars = {var.name: var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)}
assign_ops = []
for key in weights:
    if key in graph_vars:
        print("load: {}".format(key))
        assign_ops.append(tf.assign(graph_vars[key], weights[key], validate_shape=True))
    else:
        print("skip: {}".format(key))
# Run all assignments in a single session call.
sess.run(assign_ops)
print('-----------------loaded fused weights-----------------')

No answers yet.

There are no answers yet. Be the first to answer!

Related questions