pytorch 如何从MNIST数据集中仅生成3,8个样本?

64jmpszr  于 2022-12-18  发布在  其他
关注(0)|答案(1)|浏览(142)

我想在MNIST的“3”和“8”个样本上训练我的模型,如何仅生成这些样本?
我试过:

all_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
 all_train_loader = torch.utils.data.DataLoader(dataset=all_train_dataset,
                                                batch_size=batch_size,
                                                shuffle=True)

我不知道该如何继续下去。

jv2fixgn

jv2fixgn1#

您可以通过以下方式修改数据集

all_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
selection = torch.logical_or(all_train_dataset.targets == 3, all_train_dataset.targets == 8)
all_train_dataset.data = all_train_dataset.data[selection]
all_train_dataset.targets = all_train_dataset.targets[selection]

相关问题