我在pytorch中使用了RAFT模型来计算两帧之间的光流。代码如下:
noise_img = noise_img.to(device)
clean_img = clean_img.to(device)
return_index = noise_img.size(1) // 2
aligned_frames = torch.zeros((noise_img.size(0), noise_img.size(1), noise_img.size(2), noise_img.size(3), noise_img.size(4)))
aligned_frames[:, return_index, :, :, :] = noise_img[:, return_index, :, :, :]
for idx in range(noise_img.size(1)):
if not idx == return_index:
curr_frame = noise_img[:, idx, :, :, :]
ref_frame = noise_img[:, return_index, :, :, :]
curr_transf, ref_transf = transforms(curr_frame, ref_frame)
curr_flow = mc_model(curr_transf, ref_transf)[-1] # Take the final flow prediction
aligned_frames[:, idx, :, :, :] = align_frames(curr_transf, curr_flow)
在上面的例子中,我通过mc_model(RAFT)传递了两个帧,以返回一个光流图。在最后一行中,我试图将当前帧Map为与参考帧对齐。下面是我使用的函数:
def warp_flow(img, flow):
flow_permute = torch.permute(flow, (0, 2, 3, 1))
remapped = torch.nn.functional.grid_sample(img, flow_permute)
return remapped
不幸的是,remapped
在保存为图像时,不会返回连贯的图像。大多数图像为零,有些看起来像明亮的波浪。我在使用curr_flow
时遗漏了一个步骤,但我不太明白是什么。
谢谢你。
1条答案
按热度按时间v1l68za41#
如果我没记错的话,RAFT以像素为单位输出偏移量,但
torch.nn.functional.grid_sample
采用[-1,1]中的归一化图像坐标。基本上,您需要使用torch.meshgrid
来生成像素坐标,将RAFT生成的流添加到它,并将其归一化为[-1,1]。这应该用作grid_sample
的输入。