pytorch 加载不对齐预训练

x33g5p2x  于2021-11-15 转载在 其他  
字(0.6k)|赞(0)|评价(0)|浏览(269)

以前改变网络通道数,需要重新从头训练,无法加载预训练,今天研究了一下如何改变网络通道后,还有预训练模型可用,这样可以减少980%的训练时间,提供训练效率。

废话不说,直接上代码:

这个代码加载预训练模型后,再训练无效果:

backbone = MobileFace_83_w(256,l_size=[2,6,8,4]).to(0)

    backbone_pth = os.path.join("/data/408800_net.pth")

    state_dict=torch.load(backbone_pth, map_location=torch.device(0))
    # backbone.load_state_dict(state_dict,strict=False)
    bone_dict=backbone.state_dict()

    # model_end=Model_end(256).to(0)

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        head = k[:7]
        if head == 'module.':
            tmp_name = k[7:]  # remove `module.`
        else:
            tmp_name = k
            # continue
        need_v= bone_dict[tmp_name]

        if len(need_v.size())==1:
            if need_v.size(0)>v.size(0):

相关文章