tensorflow Python:访问Tensor中的单个元素

rqenqsqc  于 2022-12-28  发布在  Python
关注(0)|答案(4)|浏览(333)

这个问题是关于访问Tensor中的单个元素,比如说1,2,3。我需要访问内部元素[1,2,3](这可以使用.eval()或sess.run()来执行),但当Tensor的大小很大时,需要更长的时间)
有没有什么方法可以做得更快?

dgsult0t

dgsult0t1#

有两种主要方法可以访问Tensor中元素的子集,其中任何一种都适用于您的示例。
1.使用索引运算符(基于tf.slice())从Tensor中提取连续切片。

input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

output = input[0, :]
print sess.run(output)  # ==> [1 2 3]

索引运算符支持许多与NumPy相同的切片规范。
1.使用tf.gather()操作从Tensor中选择非连续切片。

input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

output = tf.gather(input, 0)
print sess.run(output)  # ==> [1 2 3]

output = tf.gather(input, [0, 2])
print sess.run(output)  # ==> [[1 2 3] [7 8 9]]

注意,tf.gather()只允许选择第0维中的整个切片(在矩阵示例中是整个行),因此可能需要对输入执行tf.reshape()tf.transpose()操作以获得适当的元素。

3pvhb19x

3pvhb19x2#

我希望我已经很好地理解了您的问题。您可以通过.numpy()访问TensorFlow 2中Tensor的元素。

import tensorflow as tf
t = tf.constant([[1,2,3]])

print(t.numpy()[0][1]) # This will print 2
>>> 2
flvtvl50

flvtvl503#

我怀疑是计算的其余部分花费了时间,而不是访问一个元素。
同样,结果可能需要从内存中拷贝,所以如果它在显卡上,则需要先将其拷贝回RAM,然后才能访问元素。如果是这种情况,您可以跳过它,添加一个张流操作来获取第一个元素,并只返回该元素。

yrwegjxp

yrwegjxp4#

不运行就无法获得1,2,3的第0个元素的值()-命名或求值()-ing一个操作,这个操作将获取它。因为在“run”或“eval”之前,你只有一个如何获取这个内部元素的描述(因为TF使用符号图形/计算)。所以即使你会使用tf.gather/tf.slice,你仍然需要通过eval/run来获得这些操作的值**,参见@mrry的答案。

相关问题