tensorflow 混淆tf.data.Dataset管道转换的用法

xmq68pz9  于 2023-04-21  发布在  其他
关注(0)|答案(1)|浏览(121)

我是tf.data API的新手,我试图通过使用存储在磁盘上的图像构建图像分类模型来学习它的工作原理以及如何正确使用它。
我一直在这里学习教程(来自Tensorflow.org)。我了解了它的要点,加载/处理似乎工作正常。问题从 * 配置数据集以获得性能 * 开始。我有一个函数定义如下:

def config_ds(ds):
    ds = ds.shuffle(buffer_size=ds.cardinality().numpy())
    ds = ds.map(process_img,num_parallel_calls=AUTOTUNE)
#     ds = ds.map(augment_img,num_parallel_calls=AUTOTUNE)
    ds = ds.cache()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds

(Note augment函数的Map被注解掉了-我还没有使用增强,但我想在将来使用,所以我把它留在这里)。这似乎是可行的,因为我可以生成和绘制/检查一批图像,但它相当慢,总是输出以下消息:

The calling iterator did not fully read the dataset being cached. In order to avoid
unexpected truncation of the dataset, the partially cached contents of the dataset  
will be discarded. This can happen if you have an input pipeline similar to 
`dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` 
instead.

这是否意味着数据是从磁盘中完整读取的,因此没有利用性能优化功能?我一直在阅读有关shufflecacheprefetchrepeat函数的信息,但我对它们的理解还不够好,无法理解警告信息。这里还有其他问题(eg)使我认为问题与批大小没有均匀划分数据有关,但是我试过改变批处理大小,警告仍然存在。我也试过按照警告的建议改变函数的顺序(假设在我的例子中take()batch()表示,对吗?),无济于事。

nhhxz33t

nhhxz33t1#

你的数据集有多大?我的第一个猜测是所有这些图像在某个时候都不适合RAM内存。
我不知道你是否意识到,但是你试图在第一个epoch上将map之后处理的所有内容放入RAM内存。这是当你在cache()调用中没有提供任何file_path时的默认行为。如果你这样做了,那么它会保存到磁盘中的一个文件中(这将在不同的运行中持续存在)

相关问题