pytorch ' Torch .收集'无广播

yyyllmsg  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(235)

我有一些形状[batch, time, feature]的批处理输入x,以及形状[batch, new_time]的批处理索引i,我想将它们聚集到x的时间维中。作为此操作的输出,我需要形状[batch, new_time, feature]的Tensory,其值如下:

y[b, t', f] = x[b, i[b, t'], f]

在Tensorflow中,我可以通过使用batch_dims: int argument of tf.gather来实现这一点:y = tf.gather(x, i, axis=1, batch_dims=1) .
在PyTorch中,我可以想到一些做类似事情的函数:
1.当然是torch.gather,但它没有类似于Tensorflow的batch_dims的参数。torch.gather的输出将始终具有与索引相同的形状。因此,在将feature dim传递到torch.gather之前,我需要将其取消广播到i

  1. torch.index_select,但这里的索引必须是一维的。因此,要使它工作,我需要取消广播x,以添加一个“batch * new_time“dim,然后在torch.index_select之后重新整形输出。
    1.这里,嵌入矩阵对应于x,但是这个嵌入函数不支持对权重进行批处理,所以我遇到了与torch.index_select相同的问题(看代码,tf.embedding在幕后使用了torch.index_select)。

是否可以在不依赖取消广播的情况下完成此类收集操作,取消广播对于大型dims而言效率低下?

mnemlml8

mnemlml81#

这实际上是最常见的情况:当输入Tensor和索引Tensor的维数不完全匹配时,仍然可以使用torch.gather,因为可以重写表达式:

y[b, t, f] = x[b, i[b, t], f]

作为:

y[b, t, f] = x[b, i[b, t, f], f]

这就保证了所有三个Tensor都有相同的维数。这揭示了i上的第三维,我们可以通过解压缩一个维度并将其扩展为x的形状来轻松地创建 for free。您可以对i[:,None].expand_as(x)执行此操作。
下面是一个最简单的例子:

>>> b = 2; t = 3; f = 1
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, f))

>>> x.gather(1, i[:,None].expand_as(x))

相关问题