python-3.x default_collate: TypeError: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pathlib.PosixPath'> & KeyError: 0

sqougxex  posted on 2023-04-08  in Python

Goal: iterate over the dataloader and access the torch.Tensor object data['image'] to run predictions, like this:

for data in dataloader:
    image, slide, filename = data['image'], data['slide_id'], data['filename']
    # predict

I suspect the problem is in the ApplicationDataset.collate() method.
There are 2 errors, caused by collate():

  1. TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pathlib.PosixPath'>
  2. KeyError: 0

The TypeError is caused by tile_filenames: List[Path].

def get_dataloader(slide_ids: List[str], tile_filenames: List[Path]) -> DataLoader:
    dataset = ApplicationDataset(slide_ids, tile_filenames)

    return DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, collate_fn=ApplicationDataset.collate)

The ApplicationDataset class:

from pathlib import Path
from typing import List

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

class ApplicationDataset(Dataset):
    def __init__(self, slide_ids: List[str], tile_filenames: List[Path]):
        self.slide_ids = slide_ids
        self.tile_filenames = tile_filenames

    def __len__(self):
        return len(self.tile_filenames)

    def __getitem__(self, idx):
        image = read_image(str(self.tile_filenames[idx]))
        return {
            'image': image,
            'slide_id': self.slide_ids[idx],
            'filename': self.tile_filenames[idx],
        }

    @staticmethod
    def collate(batch):
        images = [batch_item['image'] for batch_item in batch]

        images = torch.stack(images, dim=0)
        slide_ids = torch.tensor([batch_item['slide_id'] for batch_item in batch])
        filenames = [str(batch_item['filename']) for batch_item in batch]

        return images, slide_ids, filenames

Traceback:

(venv) me@laptop:~/BitBucket/project$ python app/container/application.py 
Traceback (most recent call last):
  File "/home/me/BitBucket/project/app/container/application.py", line 89, in <module>
    setup_inference(file_path_params, tile_params, fast_ai_params, dataloader)
  File "/home/me/BitBucket/project/app/container/application.py", line 65, in setup_inference
    predictions = predict_tiles(file_path_params, tile_params, dataloader, model)
  File "/home/me/BitBucket/project/app/container/model_code/predict.py", line 58, in predict_tiles
    grouped_tile_images = group_tile_images(dataloader)
  File "/home/me/BitBucket/project/app/container/model_code/predict.py", line 40, in group_tile_images
    for data in dataloader:
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 127, in __iter__
    for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
    return self._process_data(data)
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
    data.reraise()
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 164, in create_batch
    try: return (fa_collate,fa_convert)[self.prebatched](b)
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 51, in fa_collate
    return (default_collate(t) if isinstance(b, _collate_types)
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 73, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 73, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 85, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pathlib.PosixPath'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 34, in fetch
    data = next(self.dataset_iter)
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 138, in create_batches
    yield from map(self.do_batch, self.chunkify(res))
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 168, in do_batch
    def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 166, in create_batch
    if not self.prebatched: collate_error(e,b)
  File "/home/me/miniconda3/envs/venv/lib/python3.9/site-packages/fastai/data/load.py", line 75, in collate_error
    if i == 0: shape_a, type_a  = item[idx].shape, item[idx].__class__.__name__
KeyError: 0

Let me know what other details I should provide.

vlju58qv #1

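The problem is in __getitem__. It returns a Path value in the dictionary, and default_collate does not accept that type. You can convert the Path to a string and return that instead.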

def __getitem__(self, idx):
    image = read_image(str(self.tile_filenames[idx]))
    # Expect a (channels, 256, 256) tile
    assert len(image.shape) == 3 and tuple(image.shape[1:]) == (256, 256)
    return {
        'image': image,
        'slide_id': self.slide_ids[idx],
        # default_collate cannot batch Path objects, so return a plain string
        'filename': str(self.tile_filenames[idx]),
    }
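
An alternative is to leave the Path in __getitem__ and convert it in a custom collate function that returns a dict, so the data['image'] lookups in the loop still work. Note that the traceback in the question goes through fastai's fa_collate rather than ApplicationDataset.collate, which suggests (but does not confirm) that the custom collate was never called. The following is a minimal sketch, assuming a plain torch.utils.data.DataLoader; ToyTileDataset and collate_as_dict are hypothetical names, and the dataset returns zero tensors instead of calling read_image so that the example runs without image files.

```python
from pathlib import Path
from typing import Any, Dict, List

import torch
from torch.utils.data import DataLoader, Dataset


class ToyTileDataset(Dataset):
    """Hypothetical stand-in for ApplicationDataset (no files are read)."""

    def __init__(self, slide_ids: List[str], tile_filenames: List[Path]):
        self.slide_ids = slide_ids
        self.tile_filenames = tile_filenames

    def __len__(self) -> int:
        return len(self.tile_filenames)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return {
            # Placeholder tile; the real dataset would call read_image here
            'image': torch.zeros(3, 256, 256, dtype=torch.uint8),
            'slide_id': self.slide_ids[idx],
            'filename': self.tile_filenames[idx],  # still a Path at this point
        }


def collate_as_dict(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Stack the images; keep slide ids and filenames as lists of str."""
    return {
        'image': torch.stack([item['image'] for item in batch], dim=0),
        'slide_id': [item['slide_id'] for item in batch],
        'filename': [str(item['filename']) for item in batch],
    }


dataset = ToyTileDataset(['s1', 's2'], [Path('a.png'), Path('b.png')])
loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0,
                    collate_fn=collate_as_dict)

for data in loader:
    image, slide, filename = data['image'], data['slide_id'], data['filename']
```

Slide ids stay a list of strings here because torch.tensor() cannot be built from strings; if your slide ids are numeric, they could be converted to a tensor instead.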

If you need Path objects after getting a batch from the dataloader:

for data in dataloader:
    image, slide, filename = data['image'], data['slide_id'], data['filename']
    # if filename is a list
    filename = [Path(file) for file in filename]
    # else if it is a string
    # filename = Path(filename)
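
The two cases in the comments above can be folded into one small helper. This is a sketch only: to_paths is a hypothetical name, and it assumes the batch holds either a single string or a list of strings (default_collate generally leaves a batch of strings as a list, so the list case is the usual one, even with batch_size=1).

```python
from pathlib import Path
from typing import List, Union


def to_paths(filename: Union[str, List[str]]) -> Union[Path, List[Path]]:
    """Convert a filename string, or a list of them, back to Path objects."""
    if isinstance(filename, str):
        return Path(filename)
    return [Path(f) for f in filename]


single = to_paths('tiles/a.png')
many = to_paths(['tiles/a.png', 'tiles/b.png'])
```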
