I'm trying to use the TensorFlow function tf.map_fn to pass a list of arguments to a function. Here is my code:
```python
import tensorflow as tf

def my_func(a, v, c, d):
    print(a, v, c, d)

if __name__ == '__main__':
    tf.config.set_visible_devices(tf.config.list_physical_devices('GPU')[0], 'GPU')
    iterable = [['a', 'b', 'c', 's'], ['s', 'e', 'f', 'c']]
    tensor = tf.convert_to_tensor(iterable)
    dataset = tf.data.Dataset.from_tensor_slices(tensor)
    tf.map_fn(lambda x: my_func(*x), dataset)
```
But I get an error that I can't explain:
```
Traceback (most recent call last):
  File "/Volumes/WorkSSD/Notebooks/01 Convert Raw EDF to Raw CSV copy.py", line 134, in <module>
    tf.map_fn(lambda x: my_func(*x),dataset)
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 629, in new_func
    return func(*args, **kwargs)
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 561, in new_func
    return func(*args, **kwargs)
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 640, in map_fn_v2
    return map_fn(
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py", line 561, in new_func
    return func(*args, **kwargs)
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 392, in map_fn
    result_flat_signature = [
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py", line 393, in <listcomp>
    _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
  File "/Users/fabiomagarelli/.pyenv/versions/3.10.9/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4587, in _unbatch
    raise ValueError("Slicing dataset elements is not supported for rank 0.")
ValueError: Slicing dataset elements is not supported for rank 0.
```
What am I doing wrong, and how can I fix it?
1 answer
If your only goal is to print the values, you can print them like this.
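A minimal sketch of printing each row, assuming TF 2.x eager execution (the default) and that a plain Python loop over the Dataset is acceptable. In eager mode a tf.data.Dataset can be iterated directly, and each element here is a 1-D tf.string tensor of length 4. The return value of my_func is an addition for illustration only; the original function just prints.

```python
import tensorflow as tf


def my_func(a, v, c, d):
    print(a, v, c, d)
    # Assumption: returned only so the result can be inspected.
    return (a, v, c, d)


iterable = [['a', 'b', 'c', 's'], ['s', 'e', 'f', 'c']]
dataset = tf.data.Dataset.from_tensor_slices(iterable)

results = []
# Eager iteration: each `row` is a tf.string tensor of shape (4,).
for row in dataset:
    # .numpy() yields bytes objects; decode them to str before unpacking.
    args = [s.decode('utf-8') for s in row.numpy()]
    results.append(my_func(*args))
```

Because the loop runs in ordinary Python, print behaves as expected and no graph tracing is involved.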
This also works.
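A sketch of keeping tf.map_fn, assuming TF 2.3 or later (for the fn_output_signature argument). tf.map_fn unstacks its elems along the first dimension and is meant for tensors; passing a tf.data.Dataset appears to be what triggers the rank-0 error in the traceback, so this version passes the (2, 4) tensor instead. tf.map_fn also needs fn to return a tensor, so this hypothetical variant of my_func returns the four strings joined together, purely for illustration.

```python
import tensorflow as tf


def my_func(a, v, c, d):
    # tf.print works both eagerly and inside traced graphs (it writes to stderr).
    tf.print(a, v, c, d)
    # Assumption: map_fn needs a tensor result, so join the four strings.
    return tf.strings.join([a, v, c, d])


iterable = [['a', 'b', 'c', 's'], ['s', 'e', 'f', 'c']]
tensor = tf.convert_to_tensor(iterable)  # shape (2, 4), dtype tf.string

# Pass the tensor, not a Dataset: each x is one row of shape (4,).
# tf.unstack splits that row into four scalar tensors without relying on
# Python iteration over a tensor, which fails in graph mode.
result = tf.map_fn(
    lambda x: my_func(*tf.unstack(x)),
    tensor,
    fn_output_signature=tf.string,
)
```

If the data has to stay in a tf.data pipeline, Dataset.map with tf.print inside the mapped function is the analogous route; tf.map_fn itself is not the tool for iterating over a Dataset.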