pytorch mat1和mat2形状不能相乘(1x110和100x256)

ycl3bljg  于 2023-02-16  发布在  其他
关注(0)|答案(1)|浏览(278)

我试图建立一个GAN,接受标签生成新的形象。

class Generator(nn.Module):
       def __init__(self):
           super(Generator, self).__init__()
           self.fc1 = nn.Linear(100, 256)
           self.fc2 = nn.Linear(256, 512)
           self.fc3 = nn.Linear(512, 1024)
           self.fc4 = nn.Linear(1024, 784)

       def forward(self, x):
           x = F.relu(self.fc1(x))
           x = F.relu(self.fc2(x))
           x = F.relu(self.fc3(x))
           x = torch.tanh(self.fc4(x))
           return x
   

# set label
   label = 3

   # create one-hot encoded vector
   one_hot = torch.zeros(1, 10)
   one_hot[0][label] = 1

   # set noise vector
   noise = torch.randn(1, 100)

   # concatenate label and noise
   noise_with_label = torch.cat([one_hot, noise], dim=1)

   # pass through generator
   generated_image = generator(noise_with_label)

但它的投掷:

112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x110 and 100x256)

我正在使用MNIST数据集。
我试着解决它,但找不到解决的方法。

m3eecexj

m3eecexj1#

@jodag的原始评论帮助修复了这个问题:
您提供给模型的输入具有shape [1,110],因为您正在连接噪声(形状[1,100])和一个_热(shape [1,10])。但是第一层期望shape [...,100]的输入。消除误差的一种方法是将模型的第一层更改为self.fc1 = nn。线性另一种方法是将噪声定义为噪声= torch.randn(1,90)

相关问题