TensorFlow map_fn with a list of lists of strings

w46czmvw  posted on 2023-06-24 in Other

I'm trying to use the TensorFlow function tf.map_fn to pass a list of arguments to a function. Here is my code:

import tensorflow as tf

def my_func(a,v,c,d):
    print(a,v,c,d)

if __name__ == '__main__':
    tf.config.set_visible_devices(tf.config.list_physical_devices('GPU')[0],'GPU')

    iterable = [['a','b','c','s'],['s','e','f','c']]
    tensor = tf.convert_to_tensor(iterable)
    dataset = tf.data.Dataset.from_tensor_slices(tensor)
    tf.map_fn(lambda x: my_func(*x),dataset)

But I get an error that I can't explain:

Traceback (most recent call last):
  File "/Volumes/WorkSSD/Notebooks/01 Convert Raw EDF to Raw CSV copy.py", line 134, in <module>
    tf.map_fn(lambda x: my_func(*x),dataset)
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 629, in new_func
    return func(*args, **kwargs)
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 561, in new_func
    return func(*args, **kwargs)
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 640, in map_fn_v2
    return map_fn(
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 561, in new_func
    return func(*args, **kwargs)
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 392, in map_fn
    result_flat_signature = [
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 393, in <listcomp>
    _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4587, in _unbatch
    raise ValueError("Slicing dataset elements is not supported for rank 0.")
ValueError: Slicing dataset elements is not supported for rank 0.

What am I doing wrong, and how can I fix it?

n9vozmp4


If your goal is just to print the elements, you can do it like this:

import tensorflow as tf

if __name__ == '__main__':

    iterable = [['a','b','c','s'],['s','e','f','c']]
    tensor = tf.convert_to_tensor(iterable)
    dataset = tf.data.Dataset.from_tensor_slices(tensor)
    for element in dataset:
        print(element)

This also works:

import tensorflow as tf

def my_func(a):
    tf.print(a,[a])
    return a

if __name__ == '__main__':

    iterable = [[['a','b','c','s'],['s','e','f','c']]]
    tensor = tf.convert_to_tensor(iterable)
    dataset = tf.data.Dataset.from_tensor_slices(tensor)
    iterator = iter(dataset)
    tf.map_fn(my_func, iterator.get_next())
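The error itself comes from passing a tf.data.Dataset as elems: tf.map_fn expects a tensor (or a nested structure of tensors) that it can unstack along its first dimension, and the traceback shows it failing while trying to unbatch the dataset's type spec. A minimal sketch of the direct fix, assuming the same illustrative my_func that joins its arguments, passes the tensor itself and declares the output dtype with fn_output_signature.

```python
import tensorflow as tf

# Hypothetical variant of my_func: the original only printed and returned None,
# but tf.map_fn stacks fn's return values, so fn must return something that
# matches fn_output_signature (here a single tf.string tensor).
def my_func(a, b, c, d):
    tf.print(a, b, c, d)
    return tf.strings.join([a, b, c, d])

iterable = [['a', 'b', 'c', 's'], ['s', 'e', 'f', 'c']]
tensor = tf.convert_to_tensor(iterable)  # shape (2, 4), dtype tf.string

# Pass the tensor, not a Dataset: map_fn unstacks it along axis 0, so each call
# receives one row of shape (4,), indexed here into four scalar arguments.
result = tf.map_fn(
    lambda row: my_func(row[0], row[1], row[2], row[3]),
    tensor,
    fn_output_signature=tf.string,
)
print(result.numpy().tolist())  # [b'abcs', b'sefc']
```

Declaring fn_output_signature is optional here because the output dtype matches the input dtype, but it becomes necessary as soon as fn returns a different dtype than the elements it receives.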
