我有一些形状[batch, time, feature]
的批处理输入x
,以及形状[batch, new_time]
的批处理索引i
,我想将它们聚集到x
的时间维中。作为此操作的输出,我需要形状[batch, new_time, feature]
的Tensory
,其值如下:
y[b, t', f] = x[b, i[b, t'], f]
在Tensorflow中,我可以通过使用batch_dims: int
argument of tf.gather
来实现这一点:y = tf.gather(x, i, axis=1, batch_dims=1)
.
在PyTorch中,我可以想到一些做类似事情的函数:
1.当然是torch.gather
,但它没有类似于Tensorflow的batch_dims
的参数。torch.gather
的输出将始终具有与索引相同的形状。因此,在将feature
dim传递到torch.gather
之前,我需要将其取消广播到i
。
torch.index_select
,但这里的索引必须是一维的。因此,要使它工作,我需要取消广播x
,以添加一个“batch * new_time
“dim,然后在torch.index_select
之后重新整形输出。
1.这里,嵌入矩阵对应于x
,但是这个嵌入函数不支持对权重进行批处理,所以我遇到了与torch.index_select
相同的问题(看代码,tf.embedding
在幕后使用了torch.index_select
)。
是否可以在不依赖取消广播的情况下完成此类收集操作,取消广播对于大型dims而言效率低下?
1条答案
按热度按时间mnemlml81#
这实际上是最常见的情况:当输入Tensor和索引Tensor的维数不完全匹配时,仍然可以使用
torch.gather
,因为可以重写表达式:作为:
这就保证了所有三个Tensor都有相同的维数。这揭示了
i
上的第三维,我们可以通过解压缩一个维度并将其扩展为x
的形状来轻松地创建 for free。您可以对i[:,None].expand_as(x)
执行此操作。下面是一个最简单的例子: