两幅图像上的相同随机裁剪Pytorch变换

qvsjd97n  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(108)

我试图将两个图像输入到网络中,我想在这两个图像之间进行相同的变换。transforms.Compose()一次获取一个图像,并产生彼此独立的输出,但我想要相同的转换。我做了我自己的编码为hflip()现在我有兴趣得到随机作物。有没有什么方法可以做到这一点,而无需编写自定义函数?

relj7zay

relj7zay1#

我会使用这样的解决方法-从RandomCrop继承我自己的作物类,使用

…
        if self.call_is_even :
            self.ijhw = self.get_params(img, self.size)
        i, j, h, w = self.ijhw
        self.call_is_even = not self.call_is_even

字符串
代替了

i, j, h, w = self.get_params(img, self.size)


这个想法是抑制随机发生器对奇数调用

qqrboqgw

qqrboqgw2#

PtrBlck建议使用pytorch函数API来进行您自己的转换,以实现您想要的1(https://discuss.pytorch.org/t/torchvision-transfors-how-to-perform-identical-transform-on-both-image-and-target/10606/7),但我认为在大多数情况下,有一种更干净的方法:
如果图像是 Torch Tensor,则预期其具有[...,H,W]形状,其中...表示任意数量的前导维度
-- torchvision.RandomCrop-
您可以沿着通道维度堆叠图像。(也许这甚至意味着你可以沿着一个 * 新 * 的维度堆叠图像)。通过这种方式,您可以一次性应用变换--同时应用到所有图像。

更多图像的附录

对于两个图像,此操作正常。对于图像列表,您可以执行相同的操作。但是如果你有嵌套的图像列表,这会很麻烦。
我为该用例编写了一个在嵌套列表(深度正好为2)上操作的函数。请注意,此方法仅适用于Tensor,而不适用于PIL图像,因此它首先将pil图像转换为Tensor,除非您设置它不这样做--下面的test函数中的示例:

def apply_same_transform_to_all_images(transform: torch.nn.Module, images: List[List[Image.Image]], 
                                       to_ten: torch.nn.Module = None) -> List[List[torch.Tensor]]:
    """
        applies the transform to all images in the nested list - in a single run.
        Useful e.g. for consistent RandomCrop.
        Note that this will call torchvision.ToTensor() on your images unless you set `to_ten`!
          That rescales values. It also means the passed-in transform should assume tensor inputs, not PIL inputs.
    """
    if to_ten is None: to_ten = torchvision.transforms.ToTensor()

    # traverse the list to collect all images
    all_images = list()
    for outer_idx, inner_list in enumerate(images):
        for _inner_idx, image in enumerate(inner_list):
            all_images.append(to_ten(image))
    # stack all images in a new dimension
    stacked = torch.stack(all_images, dim=0)
    transformed = transform(stacked)
    transformed_iter = iter(transformed)

    # undo the traversing in the same order.
    output_nested_list = list()
    for outer_idx, inner_list in enumerate(images):
        output_nested_list.append(list())
        for _inner_idx, _image in enumerate(inner_list):
            output_nested_list[outer_idx].append(next(transformed_iter))

    return output_nested_list

def test_apply_same_transform_to_all_images():
    import torch, utils
    from torchvision.transforms import RandomCrop, Compose
    identity_transform = Compose([]) # we already have tensors
    img1 = torch.arange(4*4*3).reshape((3,4,4)) # image of shape 4x4 with 3 channels
    img2 = img1 * 2
    crop_transform = RandomCrop((2,2))
    result = apply_same_transform_to_all_images(crop_transform, to_ten = identity_transform, images = [[img1], [img1, img2]])
    assert torch.allclose(result[0][0], result[1][0])
    assert torch.allclose(result[0][0] * 2, result[1][1])

if __name__ == "__main__":
    # just for debugging
    test_apply_same_transform_to_all_images()

字符串

相关问题