python tensorflow image_dataset_from_directory中的类名称似乎是矩阵而不是向量

mbzjlibv  于 2023-03-28  发布在  Python
关注(0)|答案(1)|浏览(120)

我尝试使用tensorflow.keras.preprocessing.image_dataset_from_directory()将自定义图像数据输入到我的图像分类器中,以根据子目录的名称自动标记图像。代码成功执行,但后来发现图像被标记为[batch_size,3]矩阵,我也不确定数字3是从哪里来的(可能是通道?)。它影响了我的图像无法以这种方式标记。
这是我的代码:

import tensorflow.keras as tfk
datadir = '/content/rockpaperscissors'

train_ds = tfk.preprocessing.image_dataset_from_directory(
    datadir,
    labels =  'inferred',
    label_mode = 'categorical',
    batch_size = 10,
    image_size = (150, 150),
    shuffle = True,
    seed = 123,
    validation_split = 0.3,
    subset = 'training'
)

val_ds = tfk.preprocessing.image_dataset_from_directory(
    datadir,
    labels = 'inferred',
    label_mode = 'categorical',
    batch_size = 10,
    image_size = (150, 150),
    shuffle = True,
    seed = 123,
    validation_split = 0.3,
    subset = 'validation'
)

然后我检查了每批图像和标签数量的一致性

for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break

它输出

(10, 150, 150, 3)
(10, 3)

有人知道如何解决这个问题吗?先谢谢了

ux6nzvsh

ux6nzvsh1#

(10、150、150、3):这里3是图像通道的数量,因为是RGB,所以是3。通常,它是(samplesxHxWxc)
(10,3):这里3是数据集中的类数

相关问题