在tensorflow数据集中打乱批次

wwtsj6pe  于 2023-06-24  发布在  其他
关注(0)|答案(1)|浏览(92)

我正在阅读English-to-Spanish translation with a sequence-to-sequence Transformer教程。

def make_dataset(pairs, batch_size=64):
    eng_texts, fra_texts = zip(*pairs)
    eng_texts = list(eng_texts)
    fra_texts = list(fra_texts)
    dataset = tf.data.Dataset.from_tensor_slices((eng_texts, fra_texts))
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(format_dataset, num_parallel_calls=4)
    return dataset.shuffle(2048).prefetch(AUTOTUNE).cache()

特别是在dataset.shuffle(2048).prefetch(16).cache()行中
我的问题:
1.据我所知,这里的2048将是存储在缓冲区中的数据点的数量,而不是批处理,但将对批处理进行 Shuffle ,对吗?

  1. prefetch(16)。要预取的批数,对吗?
    编辑:3. map是在每次从数据集获取时应用于批处理,还是只在训练期间第一次应用。
dxxyhpgq

dxxyhpgq1#

问题1

应用Dataset.shuffle()Dataset.batch()转换的顺序可能会对结果数据集产生影响:

  • Dataset.batch()之前应用Dataset.shuffle()
  • 当您在Dataset.batch()之前应用Dataset.shuffle()时,将对数据集的各个元素应用混洗操作。这意味着每个批次中元素的顺序是随机的,但批次本身保持不变。
  • 当您希望在保持批处理结构的同时随机化单个元素的顺序时,这可能很有用。它确保每个批包含随机打乱的元素,但每个批中元素的相对顺序保持一致。
  • Dataset.batch()之后应用Dataset.shuffle()
  • 当在Dataset.batch()之后应用Dataset.shuffle()时,将对整个批次而不是单个元素应用混洗操作。
  • 这意味着批次本身的顺序将是随机的,可能导致时期之间的不同批次组成。
  • 当您希望对批处理本身进行 Shuffle 时,这可能很有用,从而在每个epoch中引入不同的数据分布。它可以帮助您在训练过程中减少批处理顺序的影响,这在处理顺序数据时尤其重要。

问题二

应用Dataset.prefetch()Dataset.batch()转换的顺序可能会影响数据集的行为和性能:

  • Dataset.batch()之前应用Dataset.prefetch()
  • 当您在Dataset.batch()之前应用Dataset.prefetch()时,将对数据集的各个元素执行预取操作。这意味着在模型处理当前批元素的同时,在后台获取和准备下一批元素。
  • 批处理之前的预取允许重叠执行,其中下一批处理的数据准备与模型在当前批处理上的执行同时发生。这可以帮助减少空闲时间并提高数据处理和模型训练的整体效率。
  • 通常建议使用此顺序,因为在批处理之前进行预取可以使流水线执行更顺畅,并提高GPU或CPU利用率。
  • Dataset.batch()之后应用Dataset.prefetch()
  • 当在Dataset.batch()之后应用Dataset.prefetch()时,预取操作将对整个数据批执行,而不是对单个元素执行。
  • 这意味着在模型处理当前批处理时,将在后台获取和准备多个批处理。
  • 批处理后的预取仍然可以通过将多个批处理的数据准备与模型的执行重叠来提供一些性能优势。但是,它可能不如批处理之前的预取有效,因为它在批处理级别而不是在单个元素级别工作。

问题三

如果您想应用一次转换并在多个epoch中重用它,则可以使用cache()方法显式缓存转换后的数据集。这允许将转换后的数据集存储在内存或磁盘上,并在后续时期中重用,而无需重新计算转换。

相关问题