网络结构代码：
# -*- coding: utf-8 -*-
import os
import time
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as torch_f
from torch.nn import init
def weights_init_normal(m):
    """Initialise one module's parameters from normal distributions, in place.

    Suitable for ``net.apply(weights_init_normal)``: it takes a single module
    and returns nothing. Modules are matched by class *name* (substring test),
    not by ``isinstance``:

    * name contains ``'Conv'`` or ``'Linear'``: ``weight`` is drawn from
      N(mean=0.0, std=0.02); ``bias`` is left untouched.
    * name contains ``'BatchNorm'``: ``weight`` is drawn from
      N(mean=1.0, std=0.02) and ``bias`` is filled with 0.0.
    * anything else: left unchanged.

    A matching module whose ``weight``/``bias`` attribute is missing or
    ``None`` is skipped rather than raising ``AttributeError``. Examples are a
    ``BatchNorm`` layer built with ``affine=False``, or a user-defined
    container such as ``ConvBlock`` whose class name happens to contain
    ``'Conv'``.

    Args:
        m: the module to initialise; its parameters are modified in place.
    """
    classname = m.__class__.__name__
    # getattr with a default: name matching can hit containers that have no
    # `weight`, and affine-less norm layers carry `weight = None`.
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            # init.* functions already wrap the fill in torch.no_grad().
            init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            init.constant_(bias, 0.0)
def weights_init_xavier(m):
classname = m.__class__.__name__
#print(classname)
if classname.find('Conv') != -1:
init.xavier_normal_(m.weight.data, gain=1)
elif classname.find('Linear') != -1:
init.xavier_normal_(m.weight.data, gain=1)
elif classname.find('Ba
版权说明：本文为转载文章，版权归原作者所有。版权声明
原文链接：https://blog.csdn.net/jacke121/article/details/122375415
内容来源于网络，如有侵权，请联系作者删除！