PyTorch: using a DataLoader to create a batch of tensors with different shapes

rqdpfwrv · posted on 2022-12-18

I'm getting this error:

stack expects each tensor to be equal size, but got [15, 414] at entry 0 and [31, 414] at entry 1

I'm trying to work around this (without zero-padding or resampling) by having the DataLoader build a batch of differently shaped tensors as a list, using a custom collate_fn.
I tried the following, but it still doesn't work:

def collate_fn(batch):
    num_features = len(batch[0])
    result = [[] for _ in range(batch[0])]
    for data in batch:
        for i in range(num_features):
            result[i].append(data[i])

    return result

How can I fix this function?
I'd like to end up with this structure:

batch X num_examples X imgs

where imgs is a list whose length is the number of images, and img_features is [num_examples, [imgs, features]]

krugob8w · 1#

You can try this:

def collate_fn(batch):
    return [torch.from_numpy(x) for x in batch]

and pass it to the DataLoader. That way, every batch the DataLoader returns will be a list of tensors, and torch.stack is never called on the list. Note, however, that this collate_fn expects Dataset.__getitem__ to return a NumPy array (the shapes may differ). For example:

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# An example image dataset
class ImageDataset(Dataset):
    imgs = [np.random.rand(2, 3), np.random.rand(4, 5)]  # different shapes are OK

    def __getitem__(self, index):
        return self.imgs[index]

    def __len__(self):
        return len(self.imgs)

loader = DataLoader(ImageDataset(), batch_size=2, collate_fn=collate_fn)
batches = list(loader)

assert len(batches) == 1
assert type(batches[0]) == list
assert len(batches[0]) == 2
assert all(torch.is_tensor(x) for x in batches[0])
assert [x.shape for x in batches[0]] == [(2, 3), (4, 5)]
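
If each sample returned by __getitem__ holds several features (the batch X num_examples X imgs structure described in the question), the two ideas can be combined: transpose the batch into one list per feature, as the question's collate_fn attempts (building result over range(num_features) rather than range(batch[0])), and convert each NumPy array to a tensor. Below is a minimal sketch; the (image, label) sample layout and the MultiFeatureDataset name are illustrative assumptions, not the asker's actual data.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

def collate_fn(batch):
    # Transpose the batch: build one list per feature instead of calling torch.stack
    num_features = len(batch[0])
    result = [[] for _ in range(num_features)]  # range over num_features, not batch[0]
    for sample in batch:
        for i in range(num_features):
            value = sample[i]
            # Convert NumPy arrays to tensors; leave other values (e.g. labels) as-is
            result[i].append(torch.from_numpy(value) if isinstance(value, np.ndarray) else value)
    return result

# Hypothetical dataset whose samples are (image, label) pairs with differently shaped images
class MultiFeatureDataset(Dataset):
    samples = [(np.random.rand(15, 414), 0), (np.random.rand(31, 414), 1)]

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

loader = DataLoader(MultiFeatureDataset(), batch_size=2, collate_fn=collate_fn)
imgs, labels = next(iter(loader))

assert [x.shape for x in imgs] == [(15, 414), (31, 414)]  # shapes from the original error batch fine
assert labels == [0, 1]

Since there is no single stacked tensor, whatever consumes such a batch has to iterate over the lists itself (or pad/stack later, once the shapes allow it).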
