我试图建立一个GAN,接受标签生成新的形象。
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fc1 = nn.Linear(100, 256)
self.fc2 = nn.Linear(256, 512)
self.fc3 = nn.Linear(512, 1024)
self.fc4 = nn.Linear(1024, 784)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = torch.tanh(self.fc4(x))
return x
# set label
label = 3
# create one-hot encoded vector
one_hot = torch.zeros(1, 10)
one_hot[0][label] = 1
# set noise vector
noise = torch.randn(1, 100)
# concatenate label and noise
noise_with_label = torch.cat([one_hot, noise], dim=1)
# pass through generator
generated_image = generator(noise_with_label)
但它的投掷:
112
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
115
116 def extra_repr(self) -> str:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x110 and 100x256)
我正在使用MNIST数据集。
我试着解决它,但找不到解决的方法。
1条答案
按热度按时间m3eecexj1#
@jodag的原始评论帮助修复了这个问题:
您提供给模型的输入具有shape [1,110],因为您正在连接噪声(形状[1,100])和一个_热(shape [1,10])。但是第一层期望shape [...,100]的输入。消除误差的一种方法是将模型的第一层更改为self.fc1 = nn。线性另一种方法是将噪声定义为噪声= torch.randn(1,90)