我正在尝试使用LeNet架构训练mnist数据库。
我从github(https://github.com/myleott/mnist_png)下载了mnist_png图像,它有超过50000张图像。我试图建立一个LeNet模型的预测手写数字使用LeNet架构这是写使用keras
用于生成图像的代码。
train_ds = tf.keras.utils.image_dataset_from_directory(
'mnist_png/training/',
validation_split = 0.2,
subset = "training",
seed = 123,
image_size = (32, 32),
batch_size = 100)
val_ds = tf.keras.utils.image_dataset_from_directory(
'mnist_png/training/',
validation_split = 0.2,
subset = "validation",
seed = 123,
image_size = (32, 32),
batch_size = 100)
test_ds = tf.keras.utils.image_dataset_from_directory(
'mnist_png/testing/',
seed = 123,
image_size = (32, 32),
batch_size = 1000)
输出量
Found 40818 files belonging to 7 classes.
Using 32655 files for training.
Found 40818 files belonging to 7 classes.
Using 8163 files for validation.
Found 10000 files belonging to 10 classes.
输入形状= (32, 32, 3)
我的模型摘要
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 28, 28, 6) 456
average_pooling2d (AverageP (None, 14, 14, 6) 0
ooling2D)
activation (Activation) (None, 14, 14, 6) 0
conv2d_1 (Conv2D) (None, 10, 10, 16) 2416
average_pooling2d_1 (Averag (None, 5, 5, 16) 0
ePooling2D)
activation_1 (Activation) (None, 5, 5, 16) 0
conv2d_2 (Conv2D) (None, 1, 1, 120) 48120
flatten (Flatten) (None, 120) 0
dense (Dense) (None, 84) 10164
dense_1 (Dense) (None, 10) 850
=================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
使用此代码编译的模型
model.compile(optimizer='adam', loss=losses.sparse_categorical_crossentropy, metrics=['accuracy'])
我已经训练了10个epoch,我得到了这个输出-
Epoch 1/10
327/327 [==============================] - 31s 79ms/step - loss: 0.9729 - accuracy: 0.6456 - val_loss: 0.3609 - val_accuracy: 0.8951
Epoch 2/10
327/327 [==============================] - 25s 77ms/step - loss: 0.3036 - accuracy: 0.9021 - val_loss: 0.2276 - val_accuracy: 0.9330
Epoch 3/10
327/327 [==============================] - 28s 85ms/step - loss: 0.2170 - accuracy: 0.9307 - val_loss: 0.1862 - val_accuracy: 0.9389
Epoch 4/10
327/327 [==============================] - 29s 89ms/step - loss: 0.1778 - accuracy: 0.9433 - val_loss: 0.1892 - val_accuracy: 0.9401
Epoch 5/10
327/327 [==============================] - 25s 76ms/step - loss: 0.1521 - accuracy: 0.9519 - val_loss: 0.1692 - val_accuracy: 0.9476
Epoch 6/10
327/327 [==============================] - 27s 83ms/step - loss: 0.1392 - accuracy: 0.9553 - val_loss: 0.1340 - val_accuracy: 0.9588
Epoch 7/10
327/327 [==============================] - 26s 79ms/step - loss: 0.1203 - accuracy: 0.9609 - val_loss: 0.1131 - val_accuracy: 0.9632
Epoch 8/10
327/327 [==============================] - 25s 76ms/step - loss: 0.1128 - accuracy: 0.9644 - val_loss: 0.1170 - val_accuracy: 0.9644
Epoch 9/10
327/327 [==============================] - 27s 81ms/step - loss: 0.1061 - accuracy: 0.9663 - val_loss: 0.1051 - val_accuracy: 0.9659
Epoch 10/10
327/327 [==============================] - 29s 89ms/step - loss: 0.0968 - accuracy: 0.9699 - val_loss: 0.0950 - val_accuracy: 0.9705
当我运行model.evaluate(test)
时,我得到了高损失和低准确性。
10/10 [==============================] - 4s 200ms/step - loss: 9.2694 - accuracy: 0.0656
有什么原因吗?
2条答案
按热度按时间42fyovps1#
似乎没有什么明显的错误。在test_ds中尝试设置shuffle=False。为了得到一个线索,尝试在val_ds上运行model.evaluate,看看它是否给出正确的结果。我唯一能想到的就是测试数据有问题。看看几张图片,看看它们的相关标签是否正确。
smdnsysy2#
训练数据集不完整
看起来你有一个不完整的数据集。正如您在加载文件后的输出中所看到的,它说(引用您问题中的输出):
请注意,前两个是训练和验证数据集,它们只看到40818个文件,总共有7个类,而最后一个是测试数据集,它看到了所有10个类。这意味着你只训练了7个类,而你的模型从未见过其他3个类。
如果我运行以下代码(这些是我的Jupyter笔记本中的单独单元格,您可以将其粘贴到Colab中以轻松运行),它会找到所有10个类:
输出:
因此,您应该首先解决这个问题,并确保您拥有完整的数据集,然后您可能会在测试数据集上获得更好的结果。
灰度对比RGB图像
我还建议在调用
image_dataset_from_directory()
时指定color_mode='grayscale'
,因为数据集本身是灰度的(您可以使用PIL
或matplotlib
库来验证这一点),image_dataset_from_directory()
的默认设置是将每个图像放大到RGB(3个通道),这就是您最终获得3通道输入的方式,这只是将单个灰度通道复制3次。