pytorch dataset.Dataset.set_transform()似乎不对图像应用转换

eh57zj3b  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(116)

我有一些非常简单的代码,我从HuggingFace数据集库下载一个数据集,尝试调整它的大小,并从中创建一个数据加载器。然而,我得到错误消息,图像无法堆叠,因为它们仍然是数据集中的原始大小,尽管应用了resize()函数:

from torchvision.transforms import ToTensor, Compose, Resize
from torch.utils.data import DataLoader
from datasets import load_dataset

dataset_name = 'food101'
size = 255

resize = Compose([Resize(size),ToTensor()])

dataset = load_dataset(dataset_name, split='train')
dataset.set_transform(resize)
dataset.set_format('torch')
dataloader = DataLoader(dataset, batch_size=32)

for batch in dataloader:
  print(inputs)

我得到以下错误:
运行时错误:堆栈期望每个Tensor大小相等,但在条目0处得到[512,384,3],在条目1处得到[512,512,3
我非常困惑。无论我使用set_transform()还是with_transform(),似乎都没有真正应用转换。我做错了什么?
我也试着用这样一个函数来应用它,但没有什么区别:

def transform(examples):
  examples['image'] = [resize(img) for img in examples['image']]
  return examples

dataset.set_transform(transform)
zazmityj

zazmityj1#

首先,根据数据集文档,dataset.set_format方法重置转换。所以,既然你是在resize变换中将图像变换为PytorchTensor,我相信不需要set_format。(但您仍然可以在**set_transform之前应用它,以确保)
其次,如果图像的高度和长度大小都不同,则应该为Resize((size, size))转换提供这两个维度。
总的来说,这将起作用:

resize = Compose([Resize((size, size)),ToTensor()])

def transform(examples):
  examples['image'] = [resize(img) for img in examples['image']]
  return examples

# dataset.set_format('torch')
dataset.set_transform(transform)
dataloader = DataLoader(dataset, batch_size=32)

请注意,仍然需要transform(examples)函数。

for batch in dataloader:
    print(batch.keys())
    print(batch['image'].shape)
    print(batch['label'].shape)
    break
>>>
dict_keys(['image', 'label'])
torch.Size([32, 3, 255, 255])
torch.Size([32])

相关问题