我想在MNIST的“3”和“8”个样本上训练我的模型,如何仅生成这些样本?
我试过:
all_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
all_train_loader = torch.utils.data.DataLoader(dataset=all_train_dataset,
batch_size=batch_size,
shuffle=True)
我不知道该如何继续下去。
1条答案
按热度按时间jv2fixgn1#
您可以通过以下方式修改数据集