我想从BatchDataSet中检索前N个项。我尝试了许多不同的方法来实现此目的,它们在重新计算时都检索到 * 不同 * 的项。但是,我想检索N个实际项,而不是将继续检索新项的迭代器。
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
ds = tf.keras.utils.image_dataset_from_directory(
"Images",
validation_split=0.2,
seed=123,
subset="training")
# Attempt to retrieve 9 items
test_ds = ds.take(9)
# Plot the 9 items and their labels
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[labels[i]])
plt.axis("off")
#
# AGAIN, plot the 9 items and their labels
# NOTE: This will show 9 different images, and my expectation is
# that it should show the same images as above.
#
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[labels[i]])
plt.axis("off")
1条答案
按热度按时间baubqpgj1#
对
tf.data.Dataset
的迭代每次都会触发shuffling。您可以将shuffle
设置为False
以获得确定性结果:第一次
如果您对其他数据样本感兴趣,则可以使用
tf.data.Dataset.skip
和tf.data.Dataset.take
方法。