pytorch 使用3dTensor从4dTensor中选择值

uinbv5nw  于 2022-12-18  发布在  其他
关注(0)|答案(2)|浏览(230)

我最近在pytorch中遇到了这个问题,当我使用4DTensor时,它应该被3DTensor索引。
假设我们有这个4DTensor:

possible_values.size()
torch.Size([2, 5, 5, 4])

其中:

dim 1 = batch
dim 2 = x_axis
dim 3 = y_axis
dim 4 = possible values of coordinate (x_i,y_j)

然后,我们有一个3D“索引”Tensor,它应该用于基于x和y坐标选择dim 4的值:

coordinates.size()
torch.Size([2, 5, 2])


其中:

dim 1 = batch
dim 2 = sequences of (x,y) 
dim 3 = (x,y) coordinate

例如,coordinates看起来像

[ [ [1,5] [3,3] [2,4] [1,3] [2,3] ]
  [ [1,5] [4,3] [2,1] [5,3] [5,3] ] ]

我们要做的是从一个批次中选择coordinates指定的坐标的可能值。因此,我们要从第一个批次中选择坐标[1, 5][3, 3]处的4值,依此类推。
我已经看了一些index_selectgather,但是目前还不能理解它(或者让它大致做我想要的事情)。
谢谢。

0vvn1miw

0vvn1miw1#

好的,让我们从删除批维度开始:

possible_values[i,coordinates[i,:,0],coordinates[i,:,1],:]  # [output is of shape [5,4]

上面给出了单个batch元素的正确值,现在我们需要一种方法来为i的所有值广播此操作(即跨batch维)。

possible_values[:,coordinates[:,:,0],coordinates[:,:,1],:]  # [output is of shape [2,2,5,4]

这基本上是正确的,但是它是“过度广播”的(即,它返回每个批处理元素的期望索引,对于每个批处理元素”)。现在,我们需要仅对前2个维度上的主对角线元素进行索引,以便我们获得每个批处理元素的期望索引,对于每个批处理元素:

batch_size = possible_values.shape[0]
batch_idx = torch.arange(batch_size)
possible_values[:,coordinates[:,:,0],coordinates[:,:,1],:][batch_size,batch_size,:,:]   # output is of shape [2,5,4]


这种解决方案还存在一些不足之处,因为它没有在不进行修改的情况下扩展到任意多个维度(即,如果添加了z轴,则必须向块添加额外的coordinates[:,:,2]索引,依此类推)。

dvtswwa3

dvtswwa32#

我认为您正在寻找torch.nn.functional.grid_sample
您确实需要稍微修改您的输入,但我希望它能起作用:

import torch.nn.functional as nnf

possible_values = possible_values.permute(0, 3, 1, 2)  # make the "channel" dimension the second one
out = nnf.grid_sample(input=possible_values, grid=coordinates[..., None, :], mode='nearest')

相关问题