假设我有下面的Tensor:
t = tf.convert_to_tensor([
[1,2,3,4],
[5,6,7,8]
])
还有另一个指数Tensor
i = tf.convert_to_tensor([[0],[2]])
我如何 * 收集 * 那些元素,说明[0]
引用第一个数组,[2]
引用第二个数组?从而得到结果[[1],[7]]
?
我想用一个增量值连接索引,得到[[0,0],[1,2]]
,如下所示:
i = tf.concat((tf.range(i.shape[0])[...,None] , i), axis=-1)
tf.gather_nd(t, i)
但我觉得有更好的解决办法
2条答案
按热度按时间s3fp2yjn1#
您可以使用NumPy的
take_along_axis
的TensorFlow变体,kninwzqo2#
您可以简单地将
i
与tf.range(...)
堆叠,如下所示我不确定有没有更好的解决办法。