import torch
from torch import nn
# Toy batch laid out channels-last: (batch=10, H=3, W=4, C=6).
# Named `x` rather than `input` so the builtin input() is not shadowed.
x = torch.randn(10, 3, 4, 6)
# nn.Conv2d expects channels-first (B x C x H x W), so move C to dim 1;
# .contiguous() gives a contiguous copy of the permuted (non-contiguous) view.
x = x.permute(0, 3, 1, 2).contiguous()
# Read the channel count from the tensor instead of repeating a literal, so
# in_channels, out_channels and groups cannot drift apart.
channels = x.shape[1]
# Depthwise convolution: groups == in_channels == out_channels gives each
# input channel its own single 3x3 filter, so output channel c depends only
# on input channel c. padding=1 with kernel_size=3, stride=1 keeps H and W.
conv = nn.Conv2d(
    channels, channels, kernel_size=3, stride=1, padding=1, groups=channels
)
print(conv(x))
# 1 answer (sorted by popularity / by time)
# Answer jmo0nnb31#
# Depthwise convolution is what you need: