我最近在pytorch中遇到了这个问题,当我使用4DTensor时,它应该被3DTensor索引。
假设我们有这个4DTensor:
possible_values.size()
torch.Size([2, 5, 5, 4])
其中:
dim 1 = batch
dim 2 = x_axis
dim 3 = y_axis
dim 4 = possible values of coordinate (x_i,y_j)
然后,我们有一个3D“索引”Tensor,它应该用于基于x和y坐标选择dim 4的值:
coordinates.size()
torch.Size([2, 5, 2])
型
其中:
dim 1 = batch
dim 2 = sequences of (x,y)
dim 3 = (x,y) coordinate
例如,coordinates
看起来像
[ [ [1,5] [3,3] [2,4] [1,3] [2,3] ]
[ [1,5] [4,3] [2,1] [5,3] [5,3] ] ]
我们要做的是从一个批次中选择coordinates
指定的坐标的可能值。因此,我们要从第一个批次中选择坐标[1, 5]
、[3, 3]
处的4
值,依此类推。
我已经看了一些index_select
和gather
,但是目前还不能理解它(或者让它大致做我想要的事情)。
谢谢。
2条答案
按热度按时间0vvn1miw1#
好的,让我们从删除批维度开始:
上面给出了单个batch元素的正确值,现在我们需要一种方法来为i的所有值广播此操作(即跨batch维)。
这基本上是正确的,但是它是“过度广播”的(即,它返回每个批处理元素的期望索引,对于每个批处理元素”)。现在,我们需要仅对前2个维度上的主对角线元素进行索引,以便我们获得每个批处理元素的期望索引,对于每个批处理元素:
型
这种解决方案还存在一些不足之处,因为它没有在不进行修改的情况下扩展到任意多个维度(即,如果添加了z轴,则必须向块添加额外的
coordinates[:,:,2]
索引,依此类推)。dvtswwa32#
我认为您正在寻找
torch.nn.functional.grid_sample
。您确实需要稍微修改您的输入,但我希望它能起作用: