Pytorch数据加载器不能迭代图像文件夹

mzsu5hc0  于 2023-03-02  发布在  其他
关注(0)|答案(1)|浏览(130)

我正在尝试加载数据集https://github.com/jaddoescad/ants_and_bees
但是,当我尝试迭代数据加载器时出现错误

training_dataset = datasets.ImageFolder('ants_and_bees/train', transform=transform_train)
validation_dataset = datasets.ImageFolder('ants_and_bees/val', transform=transform)

training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=20, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size = 20, shuffle=False)

def im_convert(tensor):
  image = tensor.cpu().clone().detach().numpy()
  image = image.transpose(1, 2, 0)
  image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
  image = image.clip(0, 1)
  return image

classes = ('ant', 'bee')

dataiter = iter(training_loader)
images, labels = next(dataiter)
fig = plt.figure(figsize=(25, 4))

for idx in np.arange(20):
  ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
  plt.imshow(im_convert(images[idx]))
  ax.set_title(classes[labels[idx].item()])

错误消息没有多大帮助,我在这里读到一些类似的问题,但无法找到解决方案。

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-58-fb882084a0d1> in <module>
      1 dataiter = iter(training_loader)
----> 2 images, labels = next(dataiter)
      3 fig = plt.figure(figsize=(25, 4))
      4 
      5 for idx in np.arange(20):

10 frames
/usr/local/lib/python3.8/dist-packages/PIL/TgaImagePlugin.py in _open(self)
     64         flags = i8(s[17])
     65 
---> 66         self.size = i16(s[12:]), i16(s[14:])
     67 
     68         # validate header fields

AttributeError: can't set attribute

代码来自Pytorch教程https://github.com/rslim087a/PyTorch-for-Deep-Learning-and-Computer-Vision-Course-All-Codes-/blob/master/PyTorch%20for%20Deep%20Learning%20and%20Computer%20Vision%20Course%20(All%20Codes)/Transfer_Learning.ipynb
我在谷歌协作上运行。
OBS:这似乎是Colab的问题,或者是那里的Python版本,我能够用Python 3. 9. 13环境在本地运行。

pxiryf3j

pxiryf3j1#

请将此添加到您的代码:

transform_train = transforms.Compose([
   transforms.ToTensor()

])

相关问题