我有一个torchvision.datasets
对象。我只想保留一些标签并删除其他标签。
例如,如果我的数据集是CFAR10,像这样的trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
,我将有10个标签。我只想保留前三个标签,删除其他标签。我怎么做?
P.S:我想我可以通过像这样从头开始构建一个数据集对象来做到这一点。但我猜应该有一种更短的方法来做到这一点:
class FilteredDataset(torch.utils.data.Dataset):
def __init__(self, dataset, desired_labels):
self.dataset = dataset
self.indices = [i for i, (_, target) in enumerate(self.dataset) if target in desired_labels]
def __getitem__(self, index):
return self.dataset[self.indices[index]]
def __len__(self):
return len(self.indices)
1条答案
按热度按时间j8ag8udp1#
你的方法是一个很好的方法。你也可以定义你想要保留的标签,并创建一个只包含所需标签的新数据集对象。