如何使用pyTorch指定conv2D图层中的批次维度

dgenwo3n  于 2023-02-08  发布在  其他
关注(0)|答案(1)|浏览(147)

我有一个600x600灰度图像的数据集,通过数据加载器将其分组为50张图像。
我的网络有一个包含16个过滤器的卷积层,接着是包含6x6内核的Maxpooling,然后是一个Dense层。conv2D的输出应该是out_channels*width*height/maxpool_kernel_W/maxpool_kernel_H = 16*600*600/6/6 = 160000乘以批大小50。
然而,当我尝试做一个向前传递我得到以下错误:我验证了数据的格式是否正确为[batch,n_channels,width,height](在我的例子中为[50,1,600,600])。
逻辑上输出应该是一个50x160000的矩阵,但显然它被格式化为一个80000x100的矩阵。看起来 Torch 是乘矩阵沿着错误的维度。如果有人明白为什么,请帮助我明白了。

# get data (using a fake dataset generator)
dataset = FakeData(size=500, image_size= (1, 600, 600), transform=ToTensor())
training_data, test_data = random_split(dataset,[400,100])
train_dataloader = DataLoader(training_data, batch_size=50, shuffle=True)
test_dataloader  = DataLoader(test_data, batch_size=50, shuffle=True)

net = nn.Sequential(
    nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,                     
                padding=2,           
            ),
    nn.ReLU(),  
    nn.MaxPool2d(kernel_size=6),
    nn.Linear(160000, 1000),
    nn.ReLU(),
)

optimizer = optim.Adam(net.parameters(), lr=1e-3,)

epochs = 10
for i in range(epochs):
    for (x, _) in train_dataloader:
        optimizer.zero_grad()

        # make sure the data is in the right shape
        print(x.shape) # returns torch.Size([50, 1, 600, 600])

        # error happens here, at the first forward pass
        output = net(x)

        criterion = nn.MSELoss()
        loss = criterion(output, x)
        loss.backward()
        optimizer.step()
qhhrdooz

qhhrdooz1#

如果逐层检查模型的推理,您会注意到nn.MaxPool2d返回形状为(50, 16, 100, 100)的4DTensor。例如,如果你想使空间维度变平,这将产生形状为(50, 16*100*100)的Tensor,ie.(50, 160_000)正如你所期望的那样。也就是说你需要使用nn.Flatten层。

net = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
                    nn.ReLU(),  
                    nn.MaxPool2d(kernel_size=6),
                    nn.Flatten(),
                    nn.Linear(160000, 1000),
                    nn.ReLU())

相关问题