tensorflow 如何从dtype为string的tf.tensor中获取字符串值

llew8vvj  于 2022-11-16  发布在  其他
关注(0)|答案(4)|浏览(280)

我想使用tf.data.dataset.list_files函数来填充我的数据集。
但是因为文件不是图像,我需要手动加载它。
问题是tf.data.dataset.list_files将变量作为tf.tensor传递,而我的python代码无法处理tensor。
如何从tf.tensor中获取字符串值?dtype是string。

train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))

def load_audio_file(file_path):
  print("file_path: ", file_path)
  # i want do something like string_path = convert_tensor_to_string(file_path)

文件路径为Tensor("arg0:0", shape=(), dtype=string)
我使用了tensorflow 1.13.1和渴望模式。
先谢了

83qze16e

83qze16e1#

您可以使用tf.py_func来 Package load_audio_file()

import tensorflow as tf

tf.enable_eager_execution()

def load_audio_file(file_path):
    # you should decode bytes type to string type
    print("file_path: ",bytes.decode(file_path),type(bytes.decode(file_path)))
    return file_path

train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))

for one_element in train_dataset:
    print(one_element)

file_path:  clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path:  clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path:  clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)

更新TF 2

上述解决方案将不适用于TF 2(使用2.2.0测试),即使将tf.py_func替换为tf.py_function,也会出现以下情况

InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'

要使其在TF 2中工作,请进行以下更改:

  • 删除tf.enable_eager_execution()(eager在TF 2中是enabled by default,您可以使用返回Truetf.executing_eagerly()来验证它)
  • tf.py_func替换为tf.py_function
  • file_path的所有函数内引用替换为file_path.numpy()
envsm3lx

envsm3lx2#

如果你想做一些完全自定义的事情,那么你应该把你的代码 Package 在tf.py_function中。请记住,这将导致性能下降。请参阅这里的文档和示例:
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
另一方面,如果你做的是一般性的事情,那么你不需要用py_function来 Package 你的代码,而是使用tf.strings模块中提供的任何方法。这些方法是为了处理字符串Tensor而设计的,并提供了许多常见的方法,如split、join、len等。这些方法不会对性能产生负面影响,它们将直接处理Tensor并返回修改后的Tensor。
请在此处查看tf.strings的文档:https://www.tensorflow.org/api_docs/python/tf/strings
例如,假设您要从文件名中提取标签的名称,则可以编写如下代码:

ds.map(lambda x: tf.strings.split(x, sep='$')[1])

以上假设标签由$分隔。

iqxoj9l9

iqxoj9l93#

如果您真的只想将Tensor展开到其字符串内容-您需要序列化TFRecord以便使用tf_example.SerializeToString()-以获取(可打印的)字符串值-请参见here

w8f9ii69

w8f9ii694#

你可以在bytes对象上使用.decode("utf-8")函数,这是你在对Tensor应用.numpy()方法后得到的

相关问题