Paddle static graph mode: tensors returned by a Dataset are automatically converted to Variable, causing DataLoader to fail

Posted by vktxenjb 5 months ago in Other


Environment:
Windows 11
Paddle: 2.6.1
Python: 3.8.19
CUDA: 11.8.88
cuDNN: 8.6

I defined the following Dataset:

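import os
import json

import numpy as np
from PIL import Image
from paddle import to_tensor
from paddle.io import Dataset

class MyDataSet(Dataset):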
    def __init__(self, data_paths, transform=None):
        super(MyDataSet, self).__init__()
        self.data_list = []
        self.data = []
        for data_dir, label_path in data_paths:
            label_path = os.path.join(data_dir, label_path)

            with open(label_path, encoding='utf-8') as f:

                data_set = json.loads(f.read())
                for data in data_set:
                    # print(data)
                    image_name = data["img_path"]
                    label = data["state"]
                    image_path = os.path.join(data_dir, image_name)
                    self.data_list.append([image_path, label])
        self.transform = transform
        self.flag_load_all = False

    def load_alldata(self):
        if not self.flag_load_all:
            for image_path, label in self.data_list:
                img = Image.open(image_path)
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                image = img.resize((128, 128), Image.Resampling.LANCZOS)
                image = np.array(image).astype(np.float32)
                self.data.append([image, label])
            self.flag_load_all = True

    def __getitem__(self, index):
        image = None
        label = None
        if self.flag_load_all:
            image, label = self.data[index]

        else:
            image_path, label = self.data_list[index]
            image = Image.open(image_path)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = image.resize((128, 128), Image.Resampling.LANCZOS)

            image = np.array(image).astype(np.float32)

        if self.transform is not None:
            image = self.transform(image)
        print(image)
        label = np.array(label[1:])
        # Note: `id` here is Python's built-in function, not an index, so this branch never runs
        if id == 4:
            label = 0 - label
        label = to_tensor(label, dtype="float32")
        # Return the image and its label
        return image, label

    def __len__(self):
        return len(self.data_list)

The dataset is built as follows:

from paddle.io import DataLoader
from paddle.vision.transforms import Compose, Normalize, ToTensor

transform4data = Compose([Normalize(mean=[127.5], std=[127.5]), ToTensor()])
train_custom_dataset = MyDataSet([["./data/image_set","data.json"], ], transform=transform4data)
print("read all data to memory")
train_custom_dataset.load_alldata()
train_loader = DataLoader(train_custom_dataset, batch_size=256, shuffle=True, drop_last=False, num_workers=0)

Without paddle.enable_static() everything works fine, but as soon as paddle.enable_static() is called, the tensors in the dataset are automatically converted to Variable. Iterating over the data loader with

for data in train_loader:
    print(data)

raises the following error:

Exception in thread Thread-1:
Traceback (most recent call last):
  File "D:\Anaconda\envs\smart-car\lib\threading.py", line 932, in _bootstrap_inner
    self.run()
  File "D:\Anaconda\envs\smart-car\lib\threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "D:\Anaconda\envs\smart-car\lib\site-packages\paddle\io\dataloader\dataloader_iter.py", line 235, in _thread_loop
    batch = self._dataset_fetcher.fetch(
  File "D:\Anaconda\envs\smart-car\lib\site-packages\paddle\io\dataloader\fetcher.py", line 85, in fetch
    data = self.collate_fn(data)
  File "D:\Anaconda\envs\smart-car\lib\site-packages\paddle\io\dataloader\collate.py", line 75, in default_collate_fn
    return [default_collate_fn(fields) for fields in zip(*batch)]
  File "D:\Anaconda\envs\smart-car\lib\site-packages\paddle\io\dataloader\collate.py", line 75, in <listcomp>
    return [default_collate_fn(fields) for fields in zip(*batch)]
  File "D:\Anaconda\envs\smart-car\lib\site-packages\paddle\io\dataloader\collate.py", line 77, in default_collate_fn
    raise TypeError(
TypeError: batch data con only contains: tensor, numpy.ndarray, dict, list, number, but got <class 'paddle.base.framework.Variable'>

How can this be fixed?


Answer #1 by nuypyhwy:

Just have the dataset's __getitem__ return numpy.ndarray. There is no need to manually convert image and label to paddle.Tensor.
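A minimal sketch of what the answer suggests, assuming the labels are numeric lists as implied by the question's `label[1:]`. To keep it runnable without Paddle, `normalize_chw` and `NumpyItemDataset` are hypothetical stand-ins: `normalize_chw` replaces `Compose([Normalize(mean=[127.5], std=[127.5]), ToTensor()])` with plain numpy math (scale to [-1, 1], then HWC to CHW), since `ToTensor` is the step that produces a `paddle.Tensor`. In real code the class would subclass `paddle.io.Dataset` and be passed to `DataLoader` as before. The final `np.stack` calls only approximate how the default collate function batches numpy samples.

```python
import numpy as np


def normalize_chw(image):
    """Hypothetical numpy-only replacement for Normalize + ToTensor.

    Maps pixel values from [0, 255] to [-1, 1] and transposes HWC -> CHW,
    but keeps the result as a numpy.ndarray (no paddle.Tensor is created).
    """
    image = (image.astype(np.float32) - 127.5) / 127.5
    return np.transpose(image, (2, 0, 1))


class NumpyItemDataset:
    """Hypothetical simplified dataset; in real code, subclass paddle.io.Dataset."""

    def __init__(self, samples, transform=None):
        # samples: list of (HWC float32 image, label list) pairs, like self.data after load_alldata()
        self.samples = samples
        self.transform = transform

    def __getitem__(self, index):
        image, label = self.samples[index]
        if self.transform is not None:
            image = self.transform(image)
        # Keep the label as a float32 numpy array instead of calling to_tensor
        label = np.asarray(label[1:], dtype=np.float32)
        return image, label

    def __len__(self):
        return len(self.samples)


# Usage: four fake white 128x128 RGB images with 3-element labels
samples = [(np.full((128, 128, 3), 255.0, dtype=np.float32), [0, 1.0, 2.0]) for _ in range(4)]
ds = NumpyItemDataset(samples, transform=normalize_chw)

# Roughly what batching does with numpy samples: stack each field along axis 0
images = np.stack([ds[i][0] for i in range(len(ds))])
labels = np.stack([ds[i][1] for i in range(len(ds))])
print(images.shape, labels.shape)  # (4, 3, 128, 128) (4, 2)
```

Because every sample field is a plain numpy.ndarray, the collate step only ever receives numpy arrays, so it never encounters a Variable whether or not paddle.enable_static() has been called.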
