python 如何删除pytorch数据集的一些标签

8gsdolmq  于 2023-04-10  发布在  Python
关注(0)|答案(1)|浏览(204)

我有一个torchvision.datasets对象。我只想保留一些标签并删除其他标签。
例如,如果我的数据集是CFAR10,像这样的trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True),我将有10个标签。我只想保留前三个标签,删除其他标签。我怎么做?
P.S:我想我可以通过像这样从头开始构建一个数据集对象来做到这一点。但我猜应该有一种更短的方法来做到这一点:

class FilteredDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, desired_labels):
        self.dataset = dataset
        self.indices = [i for i, (_, target) in enumerate(self.dataset) if target in desired_labels]
        
    def __getitem__(self, index):
        return self.dataset[self.indices[index]]
    
    def __len__(self):
        return len(self.indices)
j8ag8udp

j8ag8udp1#

你的方法是一个很好的方法。你也可以定义你想要保留的标签,并创建一个只包含所需标签的新数据集对象。

desired_labels = [0, 1, 2]
trainset = datasets.CIFAR10(root='./data', train=True, download=True)

filtered_trainset = torch.utils.data.Subset(trainset, [i for i in range(len(trainset)) if trainset.targets[i] in desired_labels])

相关问题