PyTorch: How to use torchvision.transforms.AugMix with torch.float32?

qlzsbp2j  posted on 2022-11-09  in Other

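I am trying to use torchvision.transforms.AugMix to apply data augmentation to an image dataset, but I get the following error: TypeError: Only torch.uint8 image tensors are supported, but found torch.float32. I tried converting the tensor to int, but that produced a different error.
This is the code where I try to use AugMix: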

import json
import os

import torch
import torchvision
from PIL import Image

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),  # resize to 224*224
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # normalization
        torchvision.transforms.AugMix()  # TypeError is raised here: the input is already a float32 tensor
    ]
)
to_tensor = torchvision.transforms.ToTensor()
Image.MAX_IMAGE_PIXELS = None

class BreastDataset(torch.utils.data.Dataset):

    def __init__(self, json_path, data_dir_path='./dataset', clinical_data_path=None, is_preloading=True):
        self.data_dir_path = data_dir_path
        self.is_preloading = is_preloading

        with open(json_path) as f:
            print(f"load data from {json_path}")
            self.json_data = json.load(f)

    def __len__(self):
        return len(self.json_data)

    def __getitem__(self, index):
        label = int(self.json_data[index]["label"])
        patient_id = self.json_data[index]["id"]
        patch_paths = self.json_data[index]["patch_paths"]

        data = {}
        if self.is_preloading:
            data["bag_tensor"] = self.bag_tensor_list[index]
        else:
            data["bag_tensor"] = self.load_bag_tensor([os.path.join(self.data_dir_path, p_path) for p_path in patch_paths])

        data["label"] = label
        data["patient_id"] = patient_id
        data["patch_paths"] = patch_paths

        return data

    def load_bag_tensor(self, patch_paths):
        """Load a bag data as tensor with shape [N, C, H, W]"""

        patch_tensor_list = []
        for p_path in patch_paths:
            patch = Image.open(p_path).convert("RGB")
            patch_tensor = transform(patch)  # [C, H, W]
            patch_tensor = torch.unsqueeze(patch_tensor, dim=0)  # [1, C, H, W]
            patch_tensor_list.append(patch_tensor)

        bag_tensor = torch.cat(patch_tensor_list, dim=0)  # [N, C, H, W]

        return bag_tensor

Any help is much appreciated. Thanks in advance!

dy2hfwbg #1

For me, applying AugMix first and then ToTensor() worked:

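from torchvision import transforms

transformation = transforms.Compose([
    transforms.AugMix(severity=6, mixture_width=2),  # runs on the PIL image (uint8 data)
    transforms.ToTensor(),  # PIL image -> float32 tensor in [0, 1]
    transforms.RandomErasing(),
    transforms.RandomGrayscale(p=0.35),
])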

yzuktlbb #2

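torchvision.transforms.AugMix expects uint8 images. That means each pixel is 1 (grayscale) or 3 (RGB) numbers between 0 and 255, which is the classic format for images.
torch.Tensor.type(torch.float32) converts a uint8 tensor to float32, but that is unlikely to be the only transformation applied to the image. float32 images are usually normalized to the [-1, 1] or [0, 1] range. Common ways to do this are: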

img = img.type(torch.float32) / 128.0 - 1.0  # [-1, 1]
img = img.type(torch.float32) / 255.0  # [0, 1]

Once you know which of the two cases you are in, you can cast back to uint8:

img = (img + 1.0) * 128.0  # case [-1, 1]
img = img * 255.0  # case [0, 1]
img = torch.clip(img, 0.0, 255.0)
img = img.type(torch.uint8)
