我尝试使用tensorflow.keras.preprocessing.image_dataset_from_directory()
将自定义图像数据输入到我的图像分类器中,以根据子目录的名称自动标记图像。代码成功执行,但后来发现图像被标记为[batch_size,3]矩阵,我也不确定数字3是从哪里来的(可能是通道?)。它影响了我的图像无法以这种方式标记。
这是我的代码:
import tensorflow.keras as tfk
datadir = '/content/rockpaperscissors'
train_ds = tfk.preprocessing.image_dataset_from_directory(
datadir,
labels = 'inferred',
label_mode = 'categorical',
batch_size = 10,
image_size = (150, 150),
shuffle = True,
seed = 123,
validation_split = 0.3,
subset = 'training'
)
val_ds = tfk.preprocessing.image_dataset_from_directory(
datadir,
labels = 'inferred',
label_mode = 'categorical',
batch_size = 10,
image_size = (150, 150),
shuffle = True,
seed = 123,
validation_split = 0.3,
subset = 'validation'
)
然后我检查了每批图像和标签数量的一致性
for image_batch, labels_batch in train_ds:
print(image_batch.shape)
print(labels_batch.shape)
break
它输出
(10, 150, 150, 3)
(10, 3)
有人知道如何解决这个问题吗?先谢谢了
1条答案
按热度按时间ux6nzvsh1#
(10、150、150、3):这里3是图像通道的数量,因为是RGB,所以是3。通常,它是(samplesxHxWxc)
(10,3):这里3是数据集中的类数