pytorch 用于收缩和扰动深度模型训练的Keras兼容代码

vc6uscn9  于 2023-01-26  发布在  其他
关注(0)|答案(1)|浏览(125)

我指的是这项研究https://proceedings.neurips.cc/paper/2020/file/288cd2567953f06e460a33951f55daaf-Paper.pdf“On Warm-Starting Neural Network Training”。在这里,作者提出了一种收缩和扰动技术,以在新到达的数据上重新训练模型。在热启动中,模型使用其先前在旧数据上训练的权重进行初始化,并在新数据上重新训练。在所提出的技术中,现有模型的权重和偏差会向零收缩,然后添加随机噪声。要收缩权重,请将其乘以0到1之间的值,典型值约为0.5。他们的官方pytorch代码可在https://github.com/JordanAsh/warm_start/blob/main/run.py上获得。在https://pureai.com/articles/2021/02/01/warm-start-ml.aspx上给出了该研究的简单解释,作者给出了一个简单的pytorch函数来执行现有模型的收缩和扰动,如下所示:

def shrink_perturb(model, lamda=0.5, sigma=0.01):
  for (name, param) in model.named_parameters():
    if 'weight' in name:   # just weights
      nc = param.shape[0]  # cols
      nr = param.shape[1]  # rows
      for i in range(nr):
        for j in range(nc):
          param.data[j][i] = \
            (lamda * param.data[j][i]) + \
            T.normal(0.0, sigma, size=(1,1))
  return

通过定义的函数,可以使用收缩-扰动技术使用如下代码来初始化预测模型:

net = Net().to(device)
fn = ".\\Models\\employee_model_first_100.pth"
net.load_state_dict(T.load(fn))
shrink_perturb(net, lamda=0.5, sigma=0.01)
# now train net as usual

是否有一个Keras兼容的版本,我们可以缩小权重,并添加随机高斯噪声到现有的模型这样的函数定义?

model = load_model('weights/model.h5')
model.summary()
shrunk_model = shrink_perturn(model,lamda=0.5,sigma=0.01)
shrunk_model.summary()
6mzjoqzu

6mzjoqzu1#

也许是这样的

ws = [w * 0.5 + tf.random.normal(w.shape) for w in model.get_weights()]
model.set_weights(ws)

相关问题