keras - How to retrieve the first N items from a TensorFlow batched dataset, rather than from an iterator that recomputes different items?

Posted by qmelpv7a on 2022-11-13 in Other

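I want to retrieve the first N items from a BatchDataset. I have tried many different ways to do this, but every one of them retrieves *different* items each time it is re-evaluated. I want to retrieve N actual items, not an iterator that keeps producing new ones.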
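One way to get "N actual items" is to copy them out of the pipeline into NumPy arrays once; after that they are ordinary values and do not change between uses. The sketch below assumes a synthetic in-memory dataset in place of the "Images" directory, and `take_items` is a hypothetical helper name:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the directory-based dataset: 20 tiny "images"
# whose pixels all equal the sample index, with labels index % 5.
fake_images = np.repeat(np.arange(20, dtype=np.float32), 12).reshape(20, 2, 2, 3)
fake_labels = np.arange(20) % 5

# Shuffled on every pass and batched, like image_dataset_from_directory.
ds = (tf.data.Dataset.from_tensor_slices((fake_images, fake_labels))
      .shuffle(20, seed=123, reshuffle_each_iteration=True)
      .batch(4))

def take_items(dataset, n):
    """Copy the first n (image, label) pairs out of a batched dataset."""
    pairs = list(dataset.unbatch().take(n).as_numpy_iterator())
    xs = np.stack([x for x, _ in pairs])
    ys = np.array([y for _, y in pairs])
    return xs, ys

# These are plain NumPy arrays now; using them again never re-runs the shuffle.
first_images, first_labels = take_items(ds, 9)
```

With the real dataset this would be `take_items(ds, 9)`; plotting `first_images` twice then shows the same images. The items are held in memory, which is fine for a handful of samples.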

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
ds = tf.keras.utils.image_dataset_from_directory(
    "Images",
    validation_split=0.2,
    seed=123,
    subset="training")
class_names = ds.class_names  # sorted subdirectory names, used as plot titles

# Attempt to retrieve 9 items
test_ds = ds.take(9)

# Plot the 9 items and their labels
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

#
# AGAIN, plot the 9 items and their labels
# NOTE: This will show 9 different images, and my expectation is 
# that it should show the same images as above.
# 
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")
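In the code above, `ds.take(9)` returns the first 9 *batches*, not 9 images: `image_dataset_from_directory` batches by default (`batch_size=32`), so each element of the dataset is a whole batch. The inner `range(9)` loop then draws 9 images from every batch onto the same 3x3 grid. A minimal sketch on a toy `range` dataset (an assumption standing in for the image data) shows the difference between counting batches and counting items:

```python
import tensorflow as tf

# Toy stand-in: 20 scalar elements grouped into batches of 4.
ds = tf.data.Dataset.range(20).batch(4)

# take(n) on a batched dataset counts batches, not individual items.
batches = list(ds.take(2).as_numpy_iterator())
print([b.tolist() for b in batches])  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# unbatch() first to count individual items instead.
items = list(ds.unbatch().take(2).as_numpy_iterator())
print([int(x) for x in items])  # [0, 1]
```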
Answer #1, by baubqpgj:

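image_dataset_from_directory shuffles by default (shuffle=True), and the shuffle is re-applied every time the tf.data.Dataset is iterated, so each loop over test_ds sees different images. Set shuffle=False to get deterministic results: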
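This matches the default of `tf.data.Dataset.shuffle`, whose `reshuffle_each_iteration` argument is `True` unless set otherwise. A small sketch on a toy `range` dataset (not the image data; `two_passes` is a hypothetical helper) compares the two settings:

```python
import tensorflow as tf

def two_passes(reshuffle):
    """Iterate the same shuffled dataset twice and return both orderings."""
    ds = tf.data.Dataset.range(10).shuffle(
        buffer_size=10, seed=123, reshuffle_each_iteration=reshuffle)
    first = [int(x) for x in ds.as_numpy_iterator()]
    second = [int(x) for x in ds.as_numpy_iterator()]
    return first, second

a, b = two_passes(reshuffle=True)   # a new order is drawn for each pass
c, d = two_passes(reshuffle=False)  # one order, reused on every pass
print(c == d)  # True
```

With `reshuffle=True` the two passes contain the same elements but will typically come out in different orders; with `reshuffle=False` they are identical.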

import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

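ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(64, 64),
  batch_size=1,   # one image per batch, so take(9) yields 9 images
  shuffle=False)  # fixed order: every pass over the dataset yields the same items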

# Retrieve the first 9 items (9 batches of 1 image each)
test_ds = ds.take(9)

class_names = ds.class_names  # e.g. ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
# First pass: plot the 9 items and their labels
plt.figure(figsize=(4, 4))
for i, (images, labels) in enumerate(test_ds):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(images[0, ...].numpy().astype("uint8"))
  plt.title(class_names[labels.numpy()[0]])
  plt.axis("off")

# Second pass: shows the same 9 images, because nothing is reshuffled
plt.figure(figsize=(4, 4))
for i, (images, labels) in enumerate(test_ds):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(images[0, ...].numpy().astype("uint8"))
  plt.title(class_names[labels.numpy()[0]])
  plt.axis("off")

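One caveat with `shuffle=False`: the files are then read in sorted order, so the first 9 training images will likely all come from the first class directory. If you want to keep shuffling and still look at the *same* random sample twice, another option is to cache the taken elements: `cache()` stores what the first complete pass produces and replays it on later passes. A minimal sketch on a toy dataset (an assumption standing in for the image data):

```python
import tensorflow as tf

# Toy stand-in for a shuffled dataset: the order changes on every pass.
ds = tf.data.Dataset.range(100).shuffle(100, seed=123)

# cache() keeps the elements produced by the first complete pass;
# later passes replay those stored elements instead of re-shuffling.
sample = ds.take(9).cache()

first = [int(x) for x in sample]
second = [int(x) for x in sample]
print(first == second)  # True
```

Applied to the question's code this would be something like `test_ds = ds.unbatch().take(9).cache()`, which caches 9 individual images rather than 9 batches of 32. The cache is only kept once a pass has run to completion, which a full `for` loop does.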
If you are interested in other data samples, you can use the tf.data.Dataset.skip and tf.data.Dataset.take methods.
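For example, to page through the dataset nine items at a time, skip the earlier pages and take the next one. A sketch on a toy dataset, where `page` is a hypothetical helper; with `batch_size=1` as in the answer, each element corresponds to one image:

```python
import tensorflow as tf

# Toy stand-in for the unshuffled, batch_size=1 image dataset.
ds = tf.data.Dataset.range(30)

def page(dataset, page_index, page_size=9):
    """Return one 'page' of elements: skip the earlier pages, then take one."""
    return dataset.skip(page_index * page_size).take(page_size)

print([int(x) for x in page(ds, 0)])  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
print([int(x) for x in page(ds, 1)])  # [9, 10, 11, 12, 13, 14, 15, 16, 17]
```

This only pages consistently when the dataset order is fixed (`shuffle=False`); with reshuffling, each `skip`/`take` pass would see a different order.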
