tensorflow 如何收集每行一个元素

1aaf6o9v  于 2022-12-23  发布在  其他
关注(0)|答案(2)|浏览(135)

假设我有下面的Tensor:

t = tf.convert_to_tensor([
  [1,2,3,4],
  [5,6,7,8]
])

还有另一个指数Tensor

i = tf.convert_to_tensor([[0],[2]])

我如何 * 收集 * 那些元素,说明[0]引用第一个数组,[2]引用第二个数组?从而得到结果[[1],[7]]
我想用一个增量值连接索引,得到[[0,0],[1,2]],如下所示:

i = tf.concat((tf.range(i.shape[0])[...,None] , i), axis=-1)
tf.gather_nd(t, i)

但我觉得有更好的解决办法

s3fp2yjn

s3fp2yjn1#

您可以使用NumPy的take_along_axis的TensorFlow变体,

tf.experimental.numpy.take_along_axis(t, i, axis=1)
kninwzqo

kninwzqo2#

您可以简单地将itf.range(...)堆叠,如下所示

import tensorflow as tf

t = tf.convert_to_tensor([
  [1,2,3,4],
  [5,6,7,8]
])
i = tf.convert_to_tensor([0, 2])

length = tf.shape(i)[0]
indices = tf.stack([tf.range(length), i], axis=1)
# [0, 0], [1, 2]]

tf.gather_nd(t, indices)
# [1, 7]

我不确定有没有更好的解决办法。

相关问题