我试着用PyTorch创建一个生成式广告系列网络。我编写了鉴别器块并打印了摘要。之后,我创建了生成器块。我定义了forward()函数,并将输入噪声维度从(批量大小,噪声调暗)至(batch_size,channel,height,width)。在运行用于获取摘要的代码时,弹出的错误说'NoneType'对象是不可调用的。我搜索了stackoverflow和其他地方,但我的问题没有解决。
我首先使用以下代码创建了一个生成器块函数:
def get_gen_block(in_channels, out_channels, kernel_size, stride, final_block = False):
if final_block == True:
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
nn.Tanh()
)
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
然后我为生成器块定义了一个类来创建几个类,并定义了forward()函数链接如下:
class Generator(nn.Module):
def __init__(self, noise_dim):
super(Generator, self).__init__()
self.noise_dim = noise_dim
self.block_1 = get_gen_block(noise_dim, 256, (3, 3), 2)
self.block_2 = get_gen_block(256, 128, (4, 4), 1)
self.block_3 = get_gen_block(128, 64, (3, 3), 2)
self.block_4 = get_gen_block(64, 1, (4, 4), 2, final_block=True)
def forward(self, r_noise_vec):
x = r_noise_vec.view(-1, self.noise_dim, 1, 1)
x1 = self.block_1(x)
x2 = self.block_2(x1)
x3 = self.block_3(x2)
x4 = self.block_4(x3)
return x4
之后,当我打印发生器的摘要时,指向行“x1= self.block_1(x)”时发生此错误,显示'NoneType' object is not callable
。任何人请帮助我解决此问题。
1条答案
按热度按时间uajslkp61#
请检查您的
get_gen_block
函数,看起来您错过了else:
分支或混淆了缩进,当final_block = False
时,它返回None
而不是当满足条件时总是返回
module1
,否则返回None
。当满足条件时,返回
module1
,否则返回module2
。现在比较缩进。