使用PyTorch进行数据扩充

pftdvrlh  于 2023-06-29  发布在  其他
关注(0)|答案(1)|浏览(88)

我正在使用torchvision.transforms.Compose函数来构造一个转换对象,以将其用于数据扩充。
当我在没有转换的情况下调用BloodCellDataSet类时,它返回字典,但是当调用transform=transforms类时,它不应用转换,甚至if self.transform语句中的print语句也不起作用。
我是否正确地使用了这些功能,或者我遗漏了什么?

import torch
import h5py
import torchvision.transforms
from PIL import Image
import numpy as np

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64,64))
])

class BloodCellDataSet(torch.utils.data.Dataset):
    
    def __init__(self, h5py_file, subset='train', transform=None):
        
        self.f =  h5py.File(h5py_file,'r')
        self.transform = transform
        self.subset = subset
    
    def __len__(self):
        return len(self.f['phase']['images'])
    
    def __getitem__(self, idx):
        
        sample = {}
        sample['image_amplitude'] = self.f['amplitude']['images'][idx]
        sample['image_phase'] = self.f['phase']['images'][idx]
        
        if self.subset == 'train':
            if 'label' not in self.f or 'mask' not in self.f:
                raise RuntimeError('The set doesn`t contain masks or labels.')
            
            sample['image_mask'] = torch.from_numpy(self.f['mask']['images'][idx])
            sample['image_label'] = self.f['label']['ground_truth'][idx].decode("utf-8")
            sample['image_label_decoded'] = LabelTransformer.decodeClass(self.f['label']['ground_truth'][idx].decode("utf-8"))

        # Do we have to create the onehot encoded labels here for classification?
        if self.transform:
            sample['image_amplitude'] = self.transform(sample['image_amplitude'])
            sample['image_phase'] = self.transform(sample['image_phase'])
            print(sample['image_amplitude'])
            print(sample['image_phase'])
            
        return sample

class LabelTransformer(object):
    
    @staticmethod
    def encodeClass(index) -> str:
        return { 0: 'rbc', 1: 'wbc', 2: 'plt', 3: 'agg', 4: 'oof'}[index]
    
    @staticmethod
    def decodeClass(class_name) -> int:
        return {'rbc': 0, 'wbc': 1, 'plt': 2, 'agg': 3, 'oof': 4}[class_name]
    
    @staticmethod
    def numberOfClasses() -> int:
        return 5
    
    @staticmethod
    def listOfClasses() -> list:
        return ['rbc', 'wbc', 'plt', 'agg', 'oof']
    
dataset_pred_trans = BloodCellDataSet('path', 'train', transform=transforms)
xcitsw88

xcitsw881#

__getitem__方法应用于一个元素(您使用idx参数指向的元素)。同样,您将转换应用于一个元素,而不是集合、列表或任何其他容器。所以我认为你应该只转换你的样本中的一个条目,类似这样:

your_transformed_image = self.transform(sample['image_amplitude'][idx])

这可能不会解决你所有的问题,但应该让你进步一点。

相关问题