python 如何使trainloader使用特定数量的图像?

frebpwbc  于 2023-01-01  发布在  Python
关注(0)|答案(1)|浏览(213)

假设我正在使用以下调用:

trainset = torchvision.datasets.ImageFolder(root="imgs/", transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=1)

据我所知,这将trainset定义为由文件夹“images”中的所有图像组成,标签由特定的文件夹位置定义。
我的问题是-是否有直接/简单的方法将trainset定义为该文件夹中图像的子样本?例如,将trainset定义为每个子文件夹中10张图像的随机样本?

6kkfgxo0

6kkfgxo01#

您可以将类DatasetFolder(或ImageFolder) Package 在另一个类中以限制数据集:

class LimitDataset(data.Dataset):
    def __init__(self, dataset, n):
        self.dataset = dataset
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return self.dataset[i]

您还可以在LimitDataset中的索引和原始数据集中的索引之间定义一些Map,以定义更复杂的行为(例如随机子集)。
如果要限制每个时段的批处理数而不是数据集大小:

from itertools import islice
for data in islice(dataloader, 0, batches_per_epoch):
    ...

请注意,如果您使用此随机化,数据集大小将保持不变,但每个时期看到的数据将受到限制。如果您不随机化数据集,这也将限制数据集大小。

相关问题