py_functionMap返回批量tensorflow数据集的未知形状

iqxoj9l9  于 2023-01-17  发布在  其他
关注(0)|答案(1)|浏览(122)

我正在尝试使用OpenCV通过tensorflow数据集管道编写一个预处理函数。下面的this post在我的情况下不起作用。
为了阐明我的观点,考虑这个虚拟Tensor:

import tensorflow as tf
import numpy as np
ds1 = tf.random.uniform(
    (6,5,4,3),
    minval=0,
    maxval=None,
    dtype=tf.dtypes.float64,
    seed=None,
    name=None
)
ds2 = tf.data.Dataset.from_tensor_slices(ds1).batch(batch_size=2)
ds2
Out[4]: <BatchDataset element_spec=TensorSpec(shape=(None, 5, 4, 3), dtype=tf.float64, name=None)>

接下来,我的目标是对这些"阵列"应用预处理步骤(在实践中使用tf.keras.preprocessing.image_dataset_from_directory获得的A.K. A图像...)
一些伪函数:

def preprocess_images(x):
    return x+1

def parse_func_decorator(x):
    return tf.py_function(preprocess_images, [x], tf.float64)

现在我想解开的谜团开始了:通过py_function应用预处理函数给出未知形状:

ds3 = ds2.map(parse_func_decorator)
ds3
Out[7]: <MapDataset element_spec=TensorSpec(shape=<unknown>, dtype=tf.float64, name=None)>

另一方面,直接Map预处理函数保存维数

ds5 = ds2.map(preprocess_images)
ds5
Out[9]: <MapDataset element_spec=TensorSpec(shape=(None, 5, 4, 3), dtype=tf.float64, name=None)>

我错过了什么?

bcs8qyzn

bcs8qyzn1#

最后,我发现Mapkeras输入层可以解决这个问题。

data_rescale = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(5, 4, 3))
])

ds4 = ds2.map(lambda x: (data_rescale(x)))
ds4
Out[19]: <MapDataset element_spec=TensorSpec(shape=(None, 5, 4, 3), dtype=tf.float32, name=None)>

相关问题