这个问题是关于访问Tensor中的单个元素,比如说1,2,3。我需要访问内部元素[1,2,3](这可以使用.eval()或sess.run()来执行),但当Tensor的大小很大时,需要更长的时间)有没有什么方法可以做得更快?
dgsult0t1#
有两种主要方法可以访问Tensor中元素的子集,其中任何一种都适用于您的示例。1.使用索引运算符(基于tf.slice())从Tensor中提取连续切片。
tf.slice()
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) output = input[0, :] print sess.run(output) # ==> [1 2 3]
索引运算符支持许多与NumPy相同的切片规范。1.使用tf.gather()操作从Tensor中选择非连续切片。
tf.gather()
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) output = tf.gather(input, 0) print sess.run(output) # ==> [1 2 3] output = tf.gather(input, [0, 2]) print sess.run(output) # ==> [[1 2 3] [7 8 9]]
注意,tf.gather()只允许选择第0维中的整个切片(在矩阵示例中是整个行),因此可能需要对输入执行tf.reshape()或tf.transpose()操作以获得适当的元素。
tf.reshape()
tf.transpose()
3pvhb19x2#
我希望我已经很好地理解了您的问题。您可以通过.numpy()访问TensorFlow 2中Tensor的元素。
.numpy()
import tensorflow as tf t = tf.constant([[1,2,3]]) print(t.numpy()[0][1]) # This will print 2
>>> 2
flvtvl503#
我怀疑是计算的其余部分花费了时间,而不是访问一个元素。同样,结果可能需要从内存中拷贝,所以如果它在显卡上,则需要先将其拷贝回RAM,然后才能访问元素。如果是这种情况,您可以跳过它,添加一个张流操作来获取第一个元素,并只返回该元素。
yrwegjxp4#
不运行就无法获得1,2,3的第0个元素的值()-命名或求值()-ing一个操作,这个操作将获取它。因为在“run”或“eval”之前,你只有一个如何获取这个内部元素的描述(因为TF使用符号图形/计算)。所以即使你会使用tf.gather/tf.slice,你仍然需要通过eval/run来获得这些操作的值**,参见@mrry的答案。
4条答案
按热度按时间dgsult0t1#
有两种主要方法可以访问Tensor中元素的子集,其中任何一种都适用于您的示例。
1.使用索引运算符(基于
tf.slice()
)从Tensor中提取连续切片。索引运算符支持许多与NumPy相同的切片规范。
1.使用
tf.gather()
操作从Tensor中选择非连续切片。注意,
tf.gather()
只允许选择第0维中的整个切片(在矩阵示例中是整个行),因此可能需要对输入执行tf.reshape()
或tf.transpose()
操作以获得适当的元素。3pvhb19x2#
我希望我已经很好地理解了您的问题。您可以通过
.numpy()
访问TensorFlow 2中Tensor的元素。flvtvl503#
我怀疑是计算的其余部分花费了时间,而不是访问一个元素。
同样,结果可能需要从内存中拷贝,所以如果它在显卡上,则需要先将其拷贝回RAM,然后才能访问元素。如果是这种情况,您可以跳过它,添加一个张流操作来获取第一个元素,并只返回该元素。
yrwegjxp4#
不运行就无法获得1,2,3的第0个元素的值()-命名或求值()-ing一个操作,这个操作将获取它。因为在“run”或“eval”之前,你只有一个如何获取这个内部元素的描述(因为TF使用符号图形/计算)。所以即使你会使用tf.gather/tf.slice,你仍然需要通过eval/run来获得这些操作的值**,参见@mrry的答案。