Tensorflow tensor_scatter_nd_add
中错误的原因和机制是什么?
尝试使用tensor_scatter_nd_add
来更新Tensor中的元素,例如tf.tensor_scatter_nd_add,它使用tf.constant
来创建更新updates = tf.constant([9, 10, 11, 12])
。
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
tensor = tf.ones([8], dtype=tf.int32)
updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
字符串
然而,下面的代码,使用tf.constant
创建更新导致错误ValueError: TypeError: Scalar tensor has no len()
。
代码摘录
def body(
_current_cell_index,
_predictions
):
_cell_index_in_current_batch = tf.cast(
_current_cell_index % num_cells_in_batch,
dtype=TYPE_INT
)
row: TYPE_FLOAT = tf.cast(
tf.math.floor(_cell_index_in_current_batch / S),
dtype=TYPE_FLOAT
)
col: TYPE_FLOAT = tf.cast(
_cell_index_in_current_batch % S,
dtype=TYPE_FLOAT
)
_predictions = tf.tensor_scatter_nd_add(
tensor=_predictions,
indices=[
[_current_cell_index, C+1], # p0_x: x in (C,(cp,x,y,w,h))
[_current_cell_index, C+2], # p0_y
[_current_cell_index, C+P+1], # p1_x
[_current_cell_index, C+P+2] # p1_y
],
updates=tf.constant([ <----- ValueError: TypeError: Scalar tensor has no `len()
col, # p0_x + col
row, # p0_y + row
col, # p0_x + col
row # p0_y + row
])
)
return [
_current_cell_index+1,
_predictions
]
型
删除tf.constant
使其工作。
_predictions = tf.tensor_scatter_nd_add(
tensor=_predictions,
indices=[
[_current_cell_index, C+1], # p0_x: x in (C,(cp,x,y,w,h))
[_current_cell_index, C+2], # p0_y
[_current_cell_index, C+P+1], # p1_x
[_current_cell_index, C+P+2] # p1_y
],
updates=[ <----- remove tf.constant make it work
col, # p0_x + col
row, # p0_y + row
col, # p0_x + col
row # p0_y + row
]
)
型
代码
TYPE_FLOAT = np.float32
TYPE_INT = np.int32
N = 2
S = 3
C = 5
B = 2
P = 5
# Tensor N number of SxS cells where each cell has (C+B*P) elements
predictions: tf.Tensor = tf.reshape(
tensor=tf.zeros(N*S*S*(C+B*P), dtype=TYPE_FLOAT),
shape=(N,S,S,(C+B*P)),
)
# Geometry of the Tensor
num_cells_in_batch = tf.constant(S*S, dtype=TYPE_INT)
num_total_cells = tf.constant(N * num_cells_in_batch, dtype=TYPE_INT)
# Loop vars for tf.while_loop
current_cell_index = tf.constant(0, dtype=TYPE_INT)
loop_vars = (
current_cell_index,
# reshape the prediction as a sequence of cells
tf.reshape(tensor=predictions, shape=(-1, C+B*P))
)
def condition(
_current_cell_index,
_predictions
):
return tf.less(_current_cell_index, num_total_cells)
def body(
_current_cell_index,
_predictions
):
_cell_index_in_current_batch = tf.cast(
_current_cell_index % num_cells_in_batch,
dtype=TYPE_INT
)
row: TYPE_FLOAT = tf.cast(
tf.math.floor(_cell_index_in_current_batch / S),
dtype=TYPE_FLOAT
)
col: TYPE_FLOAT = tf.cast(
_cell_index_in_current_batch % S,
dtype=TYPE_FLOAT
)
_predictions = tf.tensor_scatter_nd_add(
tensor=_predictions,
indices=[
[_current_cell_index, C+1], # p0_x: x in (C,(cp,x,y,w,h))
[_current_cell_index, C+2], # p0_y
[_current_cell_index, C+P+1], # p1_x
[_current_cell_index, C+P+2] # p1_y
],
updates=tf.constant([ <-----
col, # p0_x + col
row, # p0_y + row
col, # p0_x + col
row # p0_y + row
])
)
return [
_current_cell_index+1,
_predictions
]
result = tf.while_loop(
cond=condition,
body=body,
loop_vars=loop_vars
)
型
堆栈跟踪
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[33], line 1
----> 1 result = tf.while_loop(
2 cond=condition,
3 body=body,
4 loop_vars=loop_vars
5 )
6 final_cell_index = result[0]
7 updated_predictions = result[1]
File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py:660, in deprecated_arg_values.<locals>.deprecated_wrapper.<locals>.new_func(*args, **kwargs)
652 _PRINTED_WARNING[(func, arg_name)] = True
653 _log_deprecation(
654 'From %s: calling %s (from %s) with %s=%s is deprecated and '
655 'will be removed %s.\nInstructions for updating:\n%s',
(...)
658 'in a future version' if date is None else
659 ('after %s' % date), instructions)
--> 660 return func(*args, **kwargs)
File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/ops/while_loop.py:241, in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
35 @tf_export("while_loop", v1=[])
36 @deprecation.deprecated_arg_values(
37 None,
(...)
52 maximum_iterations=None,
53 name=None):
54 """Repeat `body` while the condition `cond` is true.
55
56 Note: This op is automatically used in a `tf.function` to convert Python for-
(...)
239
240 """
--> 241 return while_loop(
242 cond=cond,
243 body=body,
244 loop_vars=loop_vars,
245 shape_invariants=shape_invariants,
246 parallel_iterations=parallel_iterations,
247 back_prop=back_prop,
248 swap_memory=swap_memory,
249 name=name,
250 maximum_iterations=maximum_iterations,
251 return_same_structure=True)
File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/ops/while_loop.py:488, in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
485 loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
486 list(loop_vars))
487 while cond(*loop_vars):
--> 488 loop_vars = body(*loop_vars)
489 if try_to_pack and not isinstance(loop_vars, (list, tuple)):
490 packed = True
Cell In[32], line 32, in body(_current_cell_index, _predictions)
15 col: TYPE_FLOAT = tf.cast(
16 _cell_index_in_current_batch % S,
17 dtype=TYPE_FLOAT
18 )
19 # tf.print("_current_cell_index", _current_cell_index)
20 # tf.print("_cell_index_in_current_batch", _cell_index_in_current_batch)
21 # tf.print("row", row)
22 # tf.print("col", col)
24 _predictions = tf.tensor_scatter_nd_add(
25 tensor=_predictions,
26 indices=[
27 [_current_cell_index, C+1], # p0_x: x in (C,(cp,x,y,w,h))
28 [_current_cell_index, C+2], # p0_y
29 [_current_cell_index, C+P+1], # p1_x
30 [_current_cell_index, C+P+2] # p1_y
31 ],
---> 32 updates=tf.constant([
33 col, # p0_x + col
34 row, # p0_y + row
35 col, # p0_x + col
36 row # p0_y + row
37 ])
38 )
39 return [
40 _current_cell_index+1,
41 _predictions
42 ]
File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/constant_op.py:267, in constant(value, dtype, shape, name)
170 @tf_export("constant", v1=[])
171 def constant(value, dtype=None, shape=None, name="Const"):
172 """Creates a constant tensor from a tensor-like object.
173
174 Note: All eager `tf.Tensor` values are immutable (in contrast to
(...)
265 ValueError: if called on a symbolic tensor.
266 """
--> 267 return _constant_impl(value, dtype, shape, name, verify_shape=False,
268 allow_broadcast=True)
File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/constant_op.py:279, in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
277 with trace.Trace("tf.constant"):
278 return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
--> 279 return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
281 const_tensor = ops._create_graph_constant( # pylint: disable=protected-access
282 value, dtype, shape, name, verify_shape, allow_broadcast
283 )
284 return const_tensor
File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/constant_op.py:289, in _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
287 def _constant_eager_impl(ctx, value, dtype, shape, verify_shape):
288 """Creates a constant on the current device."""
--> 289 t = convert_to_eager_tensor(value, ctx, dtype)
290 if shape is None:
291 return t
File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/constant_op.py:102, in convert_to_eager_tensor(value, ctx, dtype)
100 dtype = dtypes.as_dtype(dtype).as_datatype_enum
101 ctx.ensure_initialized()
--> 102 return ops.EagerTensor(value, ctx.device_name, dtype)
ValueError: TypeError: Scalar tensor has no `len()`
Traceback (most recent call last):
File "/home/user/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 338, in __len__
raise TypeError("Scalar tensor has no `len()`")
TypeError: Scalar tensor has no `len()`
型
环境
Python 3.10.12
TensorFlow version: 2.14.0
Ubuntu VERSION="22.04 LTS"
型
1条答案
按热度按时间6qftjkof1#
在
tf.constant
中,您正在传递类似Tensor的对象(row
,col
),但此API需要非Tensor输入。字符串
对于上述情况,您可以使用
tf.convert_to_tensor
。从文档中:型