我在Pytorch中计算了一个流场,我如何使用它生成一个重新Map的图像?

egdjgwm8  于 2023-06-06  发布在  其他
关注(0)|答案(1)|浏览(125)

我在pytorch中使用了RAFT模型来计算两帧之间的光流。代码如下:

noise_img = noise_img.to(device)
    clean_img = clean_img.to(device)
    return_index = noise_img.size(1) // 2
    aligned_frames = torch.zeros((noise_img.size(0), noise_img.size(1), noise_img.size(2), noise_img.size(3), noise_img.size(4)))
    aligned_frames[:, return_index, :, :, :] = noise_img[:, return_index, :, :, :]

    for idx in range(noise_img.size(1)):
            if not idx == return_index:
                    curr_frame = noise_img[:, idx, :, :, :]
                    ref_frame = noise_img[:, return_index, :, :, :]
                    curr_transf, ref_transf = transforms(curr_frame, ref_frame)
                    curr_flow = mc_model(curr_transf, ref_transf)[-1] # Take the final flow prediction
                    aligned_frames[:, idx, :, :, :] = align_frames(curr_transf, curr_flow)

在上面的例子中,我通过mc_model(RAFT)传递了两个帧,以返回一个光流图。在最后一行中,我试图将当前帧Map为与参考帧对齐。下面是我使用的函数:

def warp_flow(img, flow):
    flow_permute = torch.permute(flow, (0, 2, 3, 1))
    remapped = torch.nn.functional.grid_sample(img, flow_permute)
    return remapped

不幸的是,remapped在保存为图像时,不会返回连贯的图像。大多数图像为零,有些看起来像明亮的波浪。我在使用curr_flow时遗漏了一个步骤,但我不太明白是什么。
谢谢你。

v1l68za4

v1l68za41#

如果我没记错的话,RAFT以像素为单位输出偏移量,但torch.nn.functional.grid_sample采用[-1,1]中的归一化图像坐标。基本上,您需要使用torch.meshgrid来生成像素坐标,将RAFT生成的流添加到它,并将其归一化为[-1,1]。这应该用作grid_sample的输入。

相关问题