python-3.x tensorflow -标量Tensor在tensor_scatter_nd_add处没有len()

k10s72fa  于 2023-11-20  发布在  Python
关注(0)|答案(1)|浏览(97)

Tensorflow tensor_scatter_nd_add中错误的原因和机制是什么?
我尝试使用tensor_scatter_nd_add来更新Tensor中的元素。tf.tensor_scatter_nd_add的示例中使用tf.constant来创建updates,即updates = tf.constant([9, 10, 11, 12]):

# Working reference example: `updates` is built from plain Python ints,
# so tf.constant accepts the list.
indices = tf.constant([[4], [3], [1], [7]])  # positions in `tensor` to add to
updates = tf.constant([9, 10, 11, 12])       # one value per row of `indices`
tensor = tf.ones([8], dtype=tf.int32)
# Each update is added to the existing 1 at its index, so the expected
# result is [1, 12, 1, 11, 10, 1, 1, 13].
updated = tf.tensor_scatter_nd_add(tensor, indices, updates)

然而,下面的代码同样使用tf.constant来创建updates,却导致错误ValueError: TypeError: Scalar tensor has no `len()`

代码摘录

def body(
        _current_cell_index,
        _predictions
):
    """Loop body: add a cell's grid column/row to its x/y prediction entries.

    Excerpt of the failing version. `_current_cell_index` is the running
    cell index and `_predictions` is the tensor being updated; both are
    returned (index incremented by one) as the next loop variables.

    As written, the `tf.constant([...])` call below raises
    "ValueError: TypeError: Scalar tensor has no `len()`": `col` and `row`
    are results of `tf.cast` (tensors), and `tf.constant` expects
    non-Tensor values.
    """
    # Index of the cell within its own group of num_cells_in_batch cells.
    _cell_index_in_current_batch = tf.cast(
        _current_cell_index % num_cells_in_batch,
        dtype=TYPE_INT
    )
    
    # NOTE(review): annotated as TYPE_FLOAT, but tf.cast returns a scalar
    # tensor, not a Python/NumPy number -- this is what trips tf.constant.
    row: TYPE_FLOAT = tf.cast(
        tf.math.floor(_cell_index_in_current_batch / S),
        dtype=TYPE_FLOAT
    )
    col: TYPE_FLOAT = tf.cast(
        _cell_index_in_current_batch % S,
        dtype=TYPE_FLOAT
    )
    
    # Add col/row to four elements of the current cell's row; the four
    # updates below pair up one-to-one with the four index pairs.
    _predictions = tf.tensor_scatter_nd_add(
        tensor=_predictions,
        indices=[
            [_current_cell_index, C+1],     # p0_x: x in (C,(cp,x,y,w,h)) 
            [_current_cell_index, C+2],     # p0_y 
            [_current_cell_index, C+P+1],   # p1_x  
            [_current_cell_index, C+P+2]    # p1_y 
        ],
        updates=tf.constant([     <----- ValueError: TypeError: Scalar tensor has no `len()
            col,                            # p0_x + col
            row,                            # p0_y + row
            col,                            # p0_x + col
            row                             # p0_y + row
        ])
    )        
    # Next loop variables: advance to the following cell.
    return [
        _current_cell_index+1,
        _predictions
    ]


去掉tf.constant后,代码即可正常工作。

# Working variant: `updates` is passed as a plain Python list of the scalar
# tensors `col`/`row` instead of being wrapped in tf.constant.
# NOTE(review): presumably tensor_scatter_nd_add converts the list to a
# tensor itself -- confirm against the TensorFlow documentation.
_predictions = tf.tensor_scatter_nd_add(
        tensor=_predictions,
        indices=[
            [_current_cell_index, C+1],     # p0_x: x in (C,(cp,x,y,w,h)) 
            [_current_cell_index, C+2],     # p0_y 
            [_current_cell_index, C+P+1],   # p1_x  
            [_current_cell_index, C+P+2]    # p1_y 
        ],
        updates=[     <----- remove tf.constant make it work
            col,                            # p0_x + col
            row,                            # p0_y + row
            col,                            # p0_x + col
            row                             # p0_y + row
        ]
    )

代码

# Element types used for all casts and constants below.
TYPE_FLOAT = np.float32
TYPE_INT = np.int32
# Dimensions of the predictions tensor: N groups of SxS cells, each cell
# holding C + B*P elements.
# NOTE(review): presumably S = grid size, C = class count, B = boxes per
# cell, P = values per box (cp,x,y,w,h) -- confirm with the author.
N = 2
S = 3
C = 5
B = 2
P = 5

# Tensor N number of SxS cells where each cell has (C+B*P) elements
# (all zeros initially, built flat and then reshaped to (N, S, S, C+B*P))
predictions: tf.Tensor = tf.reshape(
    tensor=tf.zeros(N*S*S*(C+B*P), dtype=TYPE_FLOAT), 
    shape=(N,S,S,(C+B*P)),
)

# Geometry of the Tensor
num_cells_in_batch = tf.constant(S*S, dtype=TYPE_INT)
# N * num_cells_in_batch is already a tensor when passed to tf.constant here.
num_total_cells = tf.constant(N * num_cells_in_batch, dtype=TYPE_INT)

# Loop vars for tf.while_loop
# (cell counter starting at 0, plus the predictions flattened to one row
# per cell)
current_cell_index = tf.constant(0, dtype=TYPE_INT)
loop_vars = (
    current_cell_index,
    # reshape the prediction as a sequence of cells
    tf.reshape(tensor=predictions, shape=(-1, C+B*P))
)

def condition(
        _current_cell_index, 
        _predictions
):
    """Loop condition: continue while the cell index is below num_total_cells.

    `_predictions` is unused here; it is accepted only because the condition
    receives the same loop variables as the loop body.
    """
    return tf.less(_current_cell_index, num_total_cells)

def body(
        _current_cell_index,
        _predictions
):
    """Loop body: add a cell's grid column/row to its x/y prediction entries.

    `_current_cell_index` is the running cell index and `_predictions` is
    the per-cell tensor being updated; both are returned (index incremented
    by one) as the next loop variables.

    As written, the `tf.constant([...])` call below raises
    "ValueError: TypeError: Scalar tensor has no `len()`": `col` and `row`
    are results of `tf.cast` (tensors), and `tf.constant` expects
    non-Tensor values.
    """
    # Index of the cell within its own group of num_cells_in_batch cells.
    _cell_index_in_current_batch = tf.cast(
        _current_cell_index % num_cells_in_batch,
        dtype=TYPE_INT
    )
    
    # NOTE(review): annotated as TYPE_FLOAT, but tf.cast returns a scalar
    # tensor, not a Python/NumPy number -- this is what trips tf.constant.
    row: TYPE_FLOAT = tf.cast(
        tf.math.floor(_cell_index_in_current_batch / S),
        dtype=TYPE_FLOAT
    )
    col: TYPE_FLOAT = tf.cast(
        _cell_index_in_current_batch % S,
        dtype=TYPE_FLOAT
    )
    
    # Add col/row to four elements of the current cell's row; the four
    # updates below pair up one-to-one with the four index pairs.
    _predictions = tf.tensor_scatter_nd_add(
        tensor=_predictions,
        indices=[
            [_current_cell_index, C+1],     # p0_x: x in (C,(cp,x,y,w,h)) 
            [_current_cell_index, C+2],     # p0_y 
            [_current_cell_index, C+P+1],   # p1_x  
            [_current_cell_index, C+P+2]    # p1_y 
        ],
        updates=tf.constant([     <-----
            col,                            # p0_x + col
            row,                            # p0_y + row
            col,                            # p0_x + col
            row                             # p0_y + row
        ])
    )        
    # Next loop variables: advance to the following cell.
    return [
        _current_cell_index+1,
        _predictions
    ]

# Run `body` repeatedly while `condition` holds, starting from `loop_vars`.
# The traceback for this call shows the loop executing `body` directly
# (eagerly), which is where the tf.constant error is raised.
# result[0] is the final cell index and result[1] the updated predictions.
result = tf.while_loop(
    cond=condition,
    body=body,
    loop_vars=loop_vars
)

堆栈跟踪

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[33], line 1
----> 1 result = tf.while_loop(
      2     cond=condition,
      3     body=body,
      4     loop_vars=loop_vars
      5 )
      6 final_cell_index = result[0]
      7 updated_predictions = result[1]

File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py:660, in deprecated_arg_values.<locals>.deprecated_wrapper.<locals>.new_func(*args, **kwargs)
    652           _PRINTED_WARNING[(func, arg_name)] = True
    653         _log_deprecation(
    654             'From %s: calling %s (from %s) with %s=%s is deprecated and '
    655             'will be removed %s.\nInstructions for updating:\n%s',
   (...)
    658             'in a future version' if date is None else
    659             ('after %s' % date), instructions)
--> 660 return func(*args, **kwargs)

File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/ops/while_loop.py:241, in while_loop_v2(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
     35 @tf_export("while_loop", v1=[])
     36 @deprecation.deprecated_arg_values(
     37     None,
   (...)
     52                   maximum_iterations=None,
     53                   name=None):
     54   """Repeat `body` while the condition `cond` is true.
     55 
     56   Note: This op is automatically used in a `tf.function` to convert Python for-
   (...)
    239 
    240   """
--> 241   return while_loop(
    242       cond=cond,
    243       body=body,
    244       loop_vars=loop_vars,
    245       shape_invariants=shape_invariants,
    246       parallel_iterations=parallel_iterations,
    247       back_prop=back_prop,
    248       swap_memory=swap_memory,
    249       name=name,
    250       maximum_iterations=maximum_iterations,
    251       return_same_structure=True)

File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/ops/while_loop.py:488, in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
    485 loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
    486                                         list(loop_vars))
    487 while cond(*loop_vars):
--> 488   loop_vars = body(*loop_vars)
    489   if try_to_pack and not isinstance(loop_vars, (list, tuple)):
    490     packed = True

Cell In[32], line 32, in body(_current_cell_index, _predictions)
     15 col: TYPE_FLOAT = tf.cast(
     16     _cell_index_in_current_batch % S,
     17     dtype=TYPE_FLOAT
     18 )
     19 # tf.print("_current_cell_index", _current_cell_index)
     20 # tf.print("_cell_index_in_current_batch", _cell_index_in_current_batch)
     21 # tf.print("row", row)
     22 # tf.print("col", col)    
     24 _predictions = tf.tensor_scatter_nd_add(
     25     tensor=_predictions,
     26     indices=[
     27         [_current_cell_index, C+1],     # p0_x: x in (C,(cp,x,y,w,h)) 
     28         [_current_cell_index, C+2],     # p0_y 
     29         [_current_cell_index, C+P+1],   # p1_x  
     30         [_current_cell_index, C+P+2]    # p1_y 
     31     ],
---> 32     updates=tf.constant([
     33         col,                            # p0_x + col
     34         row,                            # p0_y + row
     35         col,                            # p0_x + col
     36         row                             # p0_y + row
     37     ])
     38 )        
     39 return [
     40     _current_cell_index+1,
     41     _predictions
     42 ]

File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/constant_op.py:267, in constant(value, dtype, shape, name)
    170 @tf_export("constant", v1=[])
    171 def constant(value, dtype=None, shape=None, name="Const"):
    172   """Creates a constant tensor from a tensor-like object.
    173 
    174   Note: All eager `tf.Tensor` values are immutable (in contrast to
   (...)
    265     ValueError: if called on a symbolic tensor.
    266   """
--> 267   return _constant_impl(value, dtype, shape, name, verify_shape=False,
    268                         allow_broadcast=True)

File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/constant_op.py:279, in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
    277     with trace.Trace("tf.constant"):
    278       return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
--> 279   return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    281 const_tensor = ops._create_graph_constant(  # pylint: disable=protected-access
    282     value, dtype, shape, name, verify_shape, allow_broadcast
    283 )
    284 return const_tensor

File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/constant_op.py:289, in _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    287 def _constant_eager_impl(ctx, value, dtype, shape, verify_shape):
    288   """Creates a constant on the current device."""
--> 289   t = convert_to_eager_tensor(value, ctx, dtype)
    290   if shape is None:
    291     return t

File ~/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/constant_op.py:102, in convert_to_eager_tensor(value, ctx, dtype)
    100     dtype = dtypes.as_dtype(dtype).as_datatype_enum
    101 ctx.ensure_initialized()
--> 102 return ops.EagerTensor(value, ctx.device_name, dtype)

ValueError: TypeError: Scalar tensor has no `len()`
Traceback (most recent call last):

  File "/home/user/venv/ml/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 338, in __len__
    raise TypeError("Scalar tensor has no `len()`")

TypeError: Scalar tensor has no `len()`

环境

Python 3.10.12
TensorFlow version: 2.14.0
Ubuntu VERSION="22.04 LTS"

6qftjkof

6qftjkof1#

在tf.constant中,您传入的是Tensor对象(row和col),但该API需要非Tensor的输入。

valid: tf.constant([1, 2])
invalid: tf.constant([tf.cast(1, 'int32'), tf.cast(2, 'int32')])

对于上述情况,您可以改用tf.convert_to_tensor。文档中的说明如下:

tf.constant: 
   A constant value (or list) of output type dtype.

tf.convert_to_tensor
  An object whose type has a registered Tensor conversion function.

相关问题