Pytorch:ValueError:维度太多:3>2. 9/opt/anaconda3/envs/pytorch/lib/python3.7/site-packages/PIL/Image.pyin fromarray(obj,mode)

aor9mmx1  于 2023-04-21  发布在  Python
关注(0)|答案(2)|浏览(414)

我使用MNIST数据来运行pytorch的python。我喜欢只训练数字0和1的部分数据。当我尝试打印第一个图像的大小时,它会遇到以下错误:
ValueError:维度太多:3〉2
我对Python很陌生。如果我不对训练数据进行分段,程序运行得很好。下面是代码片段

subset_indices = ((train_data.train_labels == 0) + (train_data.train_labels == 1)).nonzero()
train_loader = torch.utils.data.DataLoader(train_data,batch_size=batch_size, shuffle=False,sampler=SubsetRandomSampler(subset_indices))
drkbr07n

drkbr07n1#

这个错误是由于你传递了一个三维数组到函数Image.fromarray中,这可能是在错误的模式下设置的。你需要确保mode被设置为RGB,这样它看起来就像Image.fromarray(data, mode='RGB')

pgvzfuti

pgvzfuti2#

索引只需要在1D数组中,可以使用.view(-1)完成
例如sampler = SubsetRandomSampler(subset_indices.view(-1))

相关问题