python `tf.data`:`cache`和`repeat`是否兼容,如果是,它们应该以什么顺序应用于`tf.data.Dataset`?

niwlg2el  于 2023-03-21  发布在  Python
关注(0)|答案(1)|浏览(162)

为了确保样本无限重复,可以将repeat方法(以None-1作为count的参数)应用于tf.data.Dataset。但是,由于样本生成过程耗时,因此还需要使用cache方法。这两种方法是否兼容,如果兼容,它们应该以什么顺序应用于tf.data.Datasetthe documentationthis guide都没有解决这个特定问题。
需要无限期重复样本的原因是由于tf.distribute.DistributedDataset要求在tf.keras.Model.fit中指定steps_per_epoch。此外,在每个epoch之后,数据集不会自动重用数据集,而是会耗尽,因此必须包含一定数量的批次,其数量至少等于每个epoch的步骤数乘以epoch数。
注意:由于远程计算环境的一些限制,正在使用TensorFlow-GPU版本2.4.1。

TL;DRtf.data.Dataset.repeattf.data.Dataset.cache是否兼容,如果兼容,应用到tf.data.Dataset的顺序是什么?

db2dz4w8

db2dz4w81#

显然,*.index缓存文件只有在对数据集进行一次完整的迭代后才会创建。因此,可以使用以下代码来检查**何时创建此文件。

import tensorflow as tf, os

n_originals = 100
cache_dir = "./cache_dir/"
cache_file = "cache_file"
os.mkdir(cache_dir)
ds = tf.data.Dataset.range(n_originals).cache(os.path.join(cache_dir, cache_file)).repeat()
for i, sample in enumerate(ds):
    if cache_file+".index" in os.listdir(cache_dir):
        print(i)
        print("Cache files have been created, meaning that they are now being used to iterate over.")
        break

输出:

100
Cache files have been created, meaning that they are now being used to iterate over.

这意味该高速缓存文件是在数据集的一次迭代之后创建的,并且不受repeat方法的影响,该方法然后重复来自缓存的元素。

相关问题