如何使用PyTorch从本地目录导入MNIST数据集

fsi0uk1n  于 2023-04-21  发布在  其他
关注(0)|答案(3)|浏览(238)

我正在PyTorch中编写一个著名问题MNIST database of handwritten digits的代码。我下载了训练和测试数据集(从主网站),包括标记数据集。数据集格式为t10k-images-idx3-ubyte.gz,提取后t10k-images-idx3-ubyte。我的数据集文件夹看起来像

MINST
 Data
  train-images-idx3-ubyte.gz
  train-labels-idx1-ubyte.gz
  t10k-images-idx3-ubyte.gz
  t10k-labels-idx1-ubyte.gz

现在,我写了一段代码来加载数据,就像贝娄一样

def load_dataset():
    data_path = "/home/MNIST/Data/"
    xy_trainPT = torchvision.datasets.ImageFolder(
        root=data_path, transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        xy_trainPT, batch_size=64, num_workers=0, shuffle=True
    )
    return train_loader

我的代码显示Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp
我如何解决这个问题,我还想检查我的图像是否从数据集中加载(只是一个图包含前5个图像)?

fjnneemd

fjnneemd1#

阅读Extract images from .idx3-ubyte file or GZIP via Python

更新

可以使用此格式导入数据

xy_trainPT = torchvision.datasets.MNIST(
    root="~/Handwritten_Deep_L/",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)

现在,在download=True上发生的事情首先,您的代码将检查根目录(给定路径)是否包含任何数据集。
如果no,则数据集将从Web下载。
如果yes这个路径已经包含了一个数据集,那么你的代码将使用现有的数据集,而不会从互联网上下载。
你可以检查一下,先给予一个路径without any dataset(数据会从网上下载),然后再给另一个路径which already contains dataset数据不会下载。

z31licg0

z31licg02#

欢迎来到stackoverflow!
MNIST数据集不存储为图像,而是以二进制格式存储(如ubyte扩展名所示)。因此,ImageFolder不是您想要的类型数据集。相反,您需要使用MNIST dataset class。如果您还没有这样做,它甚至可以下载数据:)
这是一个数据集类,所以只需使用正确的root路径示例化,然后将其作为dataloader的参数,一切都应该正常工作。
如果你想检查图像,只需使用dataloader的get方法,并将结果保存为png文件(你可能需要先将Tensor转换为numpy数组)。

des4xlb0

des4xlb03#

我已经下载了MNIST文件夹(通过pytorch数据集)在我的存储库的其他地方,当我需要在不同的源文件中重新下载时,我不想再次重新下载它。
我的问题是,当传递root参数时,我引用的是MNIST/文件夹,但实际上你应该引用包含MNIST/目录的父文件夹。事实上,docs提到:
root(string)-MNIST/raw/train-images-idx 3-ubyte和MNIST/raw/t10 k-images-idx 3-ubyte存在的数据集的根目录。
所以我认为应该省略路径的MNIST/部分。
所以在我的情况下,我有:mnist_dataset = torchvision.datasets.MNIST(root='../MNIST/', train=True, download=False)
希望这对你有帮助
应改为:mnist_dataset = torchvision.datasets.MNIST(root='../', train=True, download=False)

相关问题