def apply_same_transform_to_all_images(transform: torch.nn.Module, images: List[List[Image.Image]],
                                       to_ten: torch.nn.Module = None) -> List[List[torch.Tensor]]:
    """Apply one transform call to every image in a depth-2 nested list.

    All images are converted with ``to_ten``, stacked along a new leading
    dimension and passed to ``transform`` in a single call. Any random
    parameters the transform draws are therefore shared by every image. This
    is useful e.g. for a consistent ``RandomCrop``. The results are then
    regrouped to mirror the structure of ``images``.

    Note that unless ``to_ten`` is given, ``torchvision.transforms.ToTensor()``
    is applied to every image, and that rescales values. It also means
    ``transform`` receives a stacked tensor batch, not PIL images.

    Args:
        transform: callable applied exactly once to the stacked batch. Its
            output is iterated along the first dimension to recover the
            individual images, so it must keep one entry per input image.
        images: nested list, exactly two levels deep, of images. Inner lists
            may have different lengths and may be empty.
        to_ten: callable converting a single image to a tensor. Any callable
            works, not only ``nn.Module``. It defaults to
            ``torchvision.transforms.ToTensor()``. Pass an identity transform
            when the images are already tensors.

    Returns:
        A nested list with the same structure as ``images`` holding the
        transformed tensors. If ``images`` contains no images at all,
        ``transform`` is not called and the empty structure is returned.

    Raises:
        RuntimeError: from ``torch.stack`` if the converted images do not all
            have the same shape.
    """
    if to_ten is None:
        to_ten = torchvision.transforms.ToTensor()
    # Materialize every inner sequence once: it is traversed twice
    # (flattening here, regrouping below).
    nested = [list(inner) for inner in images]
    # Flatten in outer-then-inner order; the regrouping relies on this order.
    flat = [to_ten(image) for inner in nested for image in inner]
    if not flat:
        # torch.stack rejects an empty list, and there is nothing to transform.
        return [[] for _ in nested]
    # A single call over the whole batch means identical random parameters
    # for all images.
    transformed = transform(torch.stack(flat, dim=0))
    # Undo the flattening: hand out results in the same order they went in.
    rows = iter(transformed)
    return [[next(rows) for _ in inner] for inner in nested]
def test_apply_same_transform_to_all_images():
    """Check that one RandomCrop call crops every image in the nested list identically.

    The test uses tensor images, with ``img2 == 2 * img1``. If the same crop
    window is applied to both, the cropped ``img2`` must equal twice the
    cropped ``img1``, and both crops of ``img1`` must match.
    """
    # Only names this test actually uses. The former `utils` import was
    # unused and tied the test to an unrelated project module.
    import torch
    from torchvision.transforms import RandomCrop, Compose
    identity_transform = Compose([])  # inputs are already tensors, so skip ToTensor
    img1 = torch.arange(4 * 4 * 3).reshape((3, 4, 4))  # 3-channel 4x4 image
    img2 = img1 * 2
    crop_transform = RandomCrop((2, 2))
    result = apply_same_transform_to_all_images(crop_transform, to_ten=identity_transform,
                                                images=[[img1], [img1, img2]])
    # The nested structure and the requested crop size must be preserved.
    assert [len(inner) for inner in result] == [1, 2]
    assert result[0][0].shape == (3, 2, 2)
    # Integer tensors: compare exactly instead of with a float tolerance.
    assert torch.equal(result[0][0], result[1][0])
    assert torch.equal(result[0][0] * 2, result[1][1])
if __name__ == "__main__":
    # Only when the file is executed as a script (not on import):
    # run the self-test, which is handy while debugging.
    test_apply_same_transform_to_all_images()
2条答案
按热度按时间relj7zay1#
我会使用这样的变通方法：从 RandomCrop 继承一个自己的裁剪类，使用
（此处原有一段代码，已在转载中丢失）
来代替
（此处原有一段代码，已在转载中丢失）
这个想法是在奇数次调用时不再调用随机数生成器（即复用上一次生成的随机参数），从而使成对的两次调用得到相同的裁剪。
qqrboqgw2#
ptrblck 建议使用 PyTorch 的函数式 API 编写自己的变换，以实现你想要的效果[1](https://discuss.pytorch.org/t/torchvision-transfors-how-to-perform-identical-transform-on-both-image-and-target/10606/7)，但我认为在大多数情况下有一种更简洁的方法：
> 如果图像是 torch Tensor，则其形状应为 [..., H, W]，其中 ... 表示任意数量的前导维度。
> —— torchvision.transforms.RandomCrop 文档
因此，你可以沿通道维度堆叠图像（也许这甚至意味着你可以沿一个*新的*维度堆叠图像）。这样只需调用一次变换，就能同时作用于所有图像。
附录：处理更多图像
对于两个图像,此操作正常。对于图像列表,您可以执行相同的操作。但是如果你有嵌套的图像列表,这会很麻烦。
针对这种情况，我编写了一个处理嵌套列表（深度恰好为 2）的函数。请注意，该方法只适用于 Tensor，不适用于 PIL 图像，因此它会先把 PIL 图像转换为 Tensor（除非你通过 `to_ten` 参数另行指定）。用法示例见下面的 test 函数：
（此处原为代码块，即上文的 `apply_same_transform_to_all_images` 函数及其测试函数。）