Python - Tensorflow:如何正确地将函数Map到数据集

hgb9j2n6  于 2023-11-21  发布在  Python
关注(0)|答案(2)|浏览(134)

我正在学习机器学习课程,在解决给定代码的问题时遇到了一些麻烦。

import tensorflow as tf
import tensorflow_datasets as tfds

(data), info = tfds.load("iris", with_info=True, split="train")
print(info.splits)

data = data.shuffle(150)
train_data = data.take(120)
test_data = data.skip(120)

def preprocess(dataset):
    
    def _preprocess_img(image, label):
        label = tf.one_hot(label, depth=3)
        return image, label
            
    dataset = dataset.map(_preprocess_img)
    return dataset.batch(32).prefetch(tf.data.experimental.AUTOTUNE)

train_data = preprocess(train_data)
test_data = preprocess(test_data)

字符串
这只是一个代码片段,但它应该涵盖了这里的问题区域。我得到了错误消息:TypeError:outer_factory..inner_factory..tf_preprocess_img()missing 1 required positional argument:'label'
我无法解决它,有人知道这里出了什么问题吗?我的意思是,是的,函数需要label参数,但在其他示例中,我看到它似乎可以工作。但我想知道是否数据集的解包没有按预期工作?
我尝试的是改变将要Map的函数,我看了一下数据集的元素,但它真的没有帮助我获得正确的见解。我也在寻找其他例子,但我看不出这里的特定代码有什么问题。

yhuiod9q

yhuiod9q1#

导入模块

import tensorflow as tf
import tensorflow_datasets as tfds

字符串

加载Iris数据集并将其拆分为训练集和测试集

(train_data, test_data), info = tfds.load("iris", with_info=True, split=["train[:120]", "train[120:]"], as_supervised=True)

定义预处理函数

def preprocess(image, label):
    # Perform one-hot encoding for the labels
    label = tf.one_hot(label, depth=3)
    return image, label

将预处理函数Map到数据集并进行批量处理

train_data = train_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
test_data = test_data.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

可选,您可以验证拆分和第一批数据

print("Training data splits:", train_data.cardinality())
print("Testing data splits:", test_data.cardinality())

迭代数据批次(演示)

for batch in train_data.take(1):
    images, labels = batch
    print("Batch shape:", images.shape)

8ljdwjyq

8ljdwjyq2#

使用默认参数,tfds.load返回一个字典。看:

import tensorflow_datasets as tfds

data = tfds.load("iris", split="train")

next(iter(data))

个字符
这只是一个对象,所以你的预处理函数需要两个。你需要在预处理函数中使用字典格式,或者以另一种格式获取数据。要以tuple的形式获取数据并能够按原样使用你的函数,请在tfds.load中使用as_supervised=True参数。
简化的示例工作(无需更改预处理函数:

import tensorflow_datasets as tfds
import tensorflow as tf

data = tfds.load("iris", split="train", as_supervised=True) 

def preprocess(dataset):
    def _preprocess_img(image, label):
        label = tf.one_hot(label, depth=3)
        return image, label

    dataset = dataset.map(_preprocess_img)
    return dataset.batch(4)

train_data = preprocess(data)

print(next(iter(train_data)))
(<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[5.1, 3.4, 1.5, 0.2],
       [7.7, 3. , 6.1, 2.3],
       [5.7, 2.8, 4.5, 1.3],
       [6.8, 3.2, 5.9, 2.3]], dtype=float32)>, 
<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>)

的字符串

相关问题