tensorflow 数据扩充如何与使用map方法调用的预处理函数一起工作

2jcobegt  于 2023-03-30  发布在  其他
关注(0)|答案(1)|浏览(144)

在一个教程中,我发现了这样的代码来进行数据增强:

def preprocess_with_augmentation(image, label):
   resized_image = tf.image.resize(image, [224, 224])

  # data augmentation with Tensorflow
    augmented_image = tf.image.random_flip_left_right(resized_image)
    augmented_image = tf.image.random_hue(augmented_image, 0.10)
    augmented_image = tf.image.random_brightness(augmented_image, 0.06)
    augmented_image = tf.image.random_contrast(augmented_image, 0.65, 1.35)

  # run Xceptions preprocessing function
    preprocessed_image = tf.keras.applications.xception.preprocess_input(augmented_image)

    print("Working on next ")
    return preprocessed_image, label

该函数的使用方法如下:

train_data = tfds.load('tf_flowers', split="train[:80%]", as_supervised=True)
test_data  = tfds.load('tf_flowers', split="train[80%:100%]", as_supervised=True)
x_augmented_train = train_data.map(preprocess_with_augmentation).batch(32).prefetch(1)
...
history = augmentation_model.fit(x_augmented_train,epochs=10, validation_data=test_data)

这是如何创建增强数据集的?我的猜测是,对数据集进行多次迭代会多次应用预处理函数,并以这种方式为每个epoch创建一个新的增强数据集。这将需要一次又一次地调用预处理函数,并且随机增强是独立的。
因为我不确定它是如何工作的,所以我将print-statement添加到preprocess_with_augmentation函数中。但是,如果我调用programm,“Working on next”只在开始时打印一次,而不是在不同的时期打印。
如果我的推测是正确的,我应该被印很多次。
我想也许在调用fit函数时输出被抑制了,所以我把print改为递增一个计数器,这没有帮助。计数器只显示了一次调用。
我尝试的下一件事是:我使用it = iter(x_augmented_train)创建了一个迭代器,并创建了一个循环来打印每100张图像。但是如果我创建另一个迭代器it2 = iter(x_augmented_train),我希望图像看起来会有所不同,因为每个时期的增强不应该是相同的。但是图像是相同的,所以我想知道这个方法是如何工作的。也许它不起作用?

lsmepo6l

lsmepo6l1#

当你在tf.data上使用map()时,它是在Graph模式下执行的。因此python的print语句只会打印一次,你需要把它改成tf.print才能看到实际的过程。
查看更多详情:https://www.tensorflow.org/guide/intro_to_graphs#using_tffunction

相关问题