Keras LeNet: high training and validation accuracy, but low test accuracy

x9ybnkn6 posted on 2023-06-30 in Other

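I am trying to train a model with the LeNet architecture on the MNIST dataset.
I downloaded the MNIST PNG images from GitHub (https://github.com/myleott/mnist_png), which contain more than 50,000 images, and I am building the LeNet model in Keras to recognize handwritten digits.
Code used to load the images: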

import tensorflow as tf

train_ds = tf.keras.utils.image_dataset_from_directory(
  'mnist_png/training/',
  validation_split = 0.2,
  subset = "training",
  seed = 123,
  image_size = (32, 32),
  batch_size = 100)

val_ds = tf.keras.utils.image_dataset_from_directory(
  'mnist_png/training/',
  validation_split = 0.2,
  subset = "validation",
  seed = 123,
  image_size = (32, 32),
  batch_size = 100)

test_ds = tf.keras.utils.image_dataset_from_directory(
  'mnist_png/testing/',
  seed = 123,
  image_size = (32, 32),
  batch_size = 1000)

Output:

Found 40818 files belonging to 7 classes.
Using 32655 files for training.
Found 40818 files belonging to 7 classes.
Using 8163 files for validation.
Found 10000 files belonging to 10 classes.

Input shape = (32, 32, 3)
My model summary:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 28, 28, 6)         456       
                                                                 
 average_pooling2d (AverageP  (None, 14, 14, 6)        0         
 ooling2D)                                                       
                                                                 
 activation (Activation)     (None, 14, 14, 6)         0         
                                                                 
 conv2d_1 (Conv2D)           (None, 10, 10, 16)        2416      
                                                                 
 average_pooling2d_1 (Averag  (None, 5, 5, 16)         0         
 ePooling2D)                                                     
                                                                 
 activation_1 (Activation)   (None, 5, 5, 16)          0         
                                                                 
 conv2d_2 (Conv2D)           (None, 1, 1, 120)         48120     
                                                                 
 flatten (Flatten)           (None, 120)               0         
                                                                 
 dense (Dense)               (None, 84)                10164     
                                                                 
 dense_1 (Dense)             (None, 10)                850       
                                                                 
=================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
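As a sanity check on the summary above, the parameter counts can be reproduced by hand: a Conv2D layer with a k×k kernel has k·k·C_in·C_out weights plus C_out biases, and a Dense layer has n_in·n_out weights plus n_out biases. The minimal sketch below does that arithmetic; the helper names are made up for illustration, and 5×5 kernels are an assumption inferred from the 32→28, 14→10 and 5→1 shape changes. Note that the first layer's 456 parameters only add up with C_in = 3, consistent with the (32, 32, 3) input shape.

```python
def conv2d_params(kernel: int, c_in: int, c_out: int) -> int:
    """Weights (kernel x kernel x c_in x c_out) plus one bias per output filter."""
    return kernel * kernel * c_in * c_out + c_out


def dense_params(n_in: int, n_out: int) -> int:
    """Weights (n_in x n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


def lenet_param_counts(input_channels: int) -> dict:
    # Layer sizes read off the model summary; 5x5 kernels are assumed.
    # Pooling, Activation and Flatten layers have no parameters.
    return {
        "conv2d": conv2d_params(5, input_channels, 6),
        "conv2d_1": conv2d_params(5, 6, 16),
        "conv2d_2": conv2d_params(5, 16, 120),
        "dense": dense_params(120, 84),
        "dense_1": dense_params(84, 10),
    }


rgb = lenet_param_counts(input_channels=3)
gray = lenet_param_counts(input_channels=1)
print(rgb["conv2d"], sum(rgb.values()))    # 456 62006
print(gray["conv2d"], sum(gray.values()))  # 156 61706
```

With a single grayscale channel, only the first convolution changes (156 instead of 456 parameters); everything after it is identical.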

The model is compiled with:

from tensorflow.keras import losses

model.compile(optimizer='adam', loss=losses.sparse_categorical_crossentropy, metrics=['accuracy'])
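For reference, `sparse_categorical_crossentropy` expects integer class indices as labels (which is what `image_dataset_from_directory()` produces with its default `label_mode='int'`) and, with the default `from_logits=False`, per-class probabilities as predictions. The per-sample loss is the negative log of the probability assigned to the true class, averaged over the batch. The NumPy sketch below computes that by hand; it illustrates the formula and is not Keras's exact implementation (the small `eps` clip is an assumption added to avoid `log(0)`).

```python
import numpy as np


def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean of -log(probability assigned to the true class).

    labels: integer class indices, shape (batch,)
    probs:  predicted class probabilities, shape (batch, num_classes)
    """
    labels = np.asarray(labels)
    # Clip to avoid log(0) for confident wrong predictions.
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1.0 - eps)
    # Pick, for each sample, the probability of its true class.
    true_class_probs = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(true_class_probs)))


labels = [0, 2]
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.1, 0.8]]
print(round(sparse_categorical_crossentropy(labels, probs), 4))  # 0.2899
```

Because the loss only looks at the probability at the labelled index, it silently produces large values (rather than an error) when the integer labels mean different digits than the model was trained on.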

I trained for 10 epochs and got this output:

Epoch 1/10
327/327 [==============================] - 31s 79ms/step - loss: 0.9729 - accuracy: 0.6456 - val_loss: 0.3609 - val_accuracy: 0.8951
Epoch 2/10
327/327 [==============================] - 25s 77ms/step - loss: 0.3036 - accuracy: 0.9021 - val_loss: 0.2276 - val_accuracy: 0.9330
Epoch 3/10
327/327 [==============================] - 28s 85ms/step - loss: 0.2170 - accuracy: 0.9307 - val_loss: 0.1862 - val_accuracy: 0.9389
Epoch 4/10
327/327 [==============================] - 29s 89ms/step - loss: 0.1778 - accuracy: 0.9433 - val_loss: 0.1892 - val_accuracy: 0.9401
Epoch 5/10
327/327 [==============================] - 25s 76ms/step - loss: 0.1521 - accuracy: 0.9519 - val_loss: 0.1692 - val_accuracy: 0.9476
Epoch 6/10
327/327 [==============================] - 27s 83ms/step - loss: 0.1392 - accuracy: 0.9553 - val_loss: 0.1340 - val_accuracy: 0.9588
Epoch 7/10
327/327 [==============================] - 26s 79ms/step - loss: 0.1203 - accuracy: 0.9609 - val_loss: 0.1131 - val_accuracy: 0.9632
Epoch 8/10
327/327 [==============================] - 25s 76ms/step - loss: 0.1128 - accuracy: 0.9644 - val_loss: 0.1170 - val_accuracy: 0.9644
Epoch 9/10
327/327 [==============================] - 27s 81ms/step - loss: 0.1061 - accuracy: 0.9663 - val_loss: 0.1051 - val_accuracy: 0.9659
Epoch 10/10
327/327 [==============================] - 29s 89ms/step - loss: 0.0968 - accuracy: 0.9699 - val_loss: 0.0950 - val_accuracy: 0.9705

When I run model.evaluate(test_ds), I get a high loss and a low accuracy:

10/10 [==============================] - 4s 200ms/step - loss: 9.2694 - accuracy: 0.0656

What could be causing this?

Answer 1, by 42fyovps:

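Nothing looks obviously wrong. Try setting shuffle=False for test_ds. To get a clue, try running model.evaluate on val_ds and check whether it gives the expected result. The only thing I can think of is that something is wrong with the test data: look at a few images and check whether their associated labels are correct.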
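Another cheap check along these lines is to compare which labels actually occur in each split. The sketch below is a hypothetical helper (`label_histogram` is not part of Keras). It accepts any iterable of `(images, labels)` batches, so it should also work on the `tf.data` datasets from the question, e.g. `label_histogram(val_ds)` and `label_histogram(test_ds)`, alongside the `class_names` attribute that `image_dataset_from_directory()` attaches to the returned dataset.

```python
from collections import Counter

import numpy as np


def label_histogram(batches):
    """Count how often each integer label occurs across (images, labels) batches."""
    counts = Counter()
    for _images, labels in batches:
        # np.asarray also accepts eager TensorFlow tensors.
        counts.update(np.asarray(labels).ravel().tolist())
    return dict(sorted(counts.items()))


# Stand-in for a real dataset: two batches of dummy images with integer labels.
fake_batches = [
    (np.zeros((3, 32, 32, 3)), np.array([0, 1, 1])),
    (np.zeros((2, 32, 32, 3)), np.array([2, 0])),
]
print(label_histogram(fake_batches))  # {0: 2, 1: 2, 2: 1}
```

If the validation histogram had fewer distinct labels than the test histogram, that would point at a data problem rather than a modeling one.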

Answer 2, by smdnsysy:

Incomplete training dataset

It looks like your dataset is incomplete. As you can see in the output printed after loading the files (quoting the output from your question):

Found 40818 files belonging to 7 classes.
Using 32655 files for training.
Found 40818 files belonging to 7 classes.
Using 8163 files for validation.
Found 10000 files belonging to 10 classes.

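Note that the first two are the training and validation datasets, which only see 40818 files in a total of 7 classes, while the last one, the test dataset, sees all 10 classes. This means you trained on only 7 classes, and your model has never seen the other 3.
If I run the following code (these are separate cells in my Jupyter notebook; you can paste them into Colab to run them easily), it finds all 10 classes: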
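The damage can go beyond 3 unseen classes. With the default `labels='inferred'`, `image_dataset_from_directory()` numbers the classes by the alphanumerical order of the subdirectory names it finds (unless `class_names` is passed). If the 7 training folders are not exactly `0` through `6`, the same integer label refers to different digits in the training and test sets, so even digits the model did learn would be scored against the wrong index; this may help explain an accuracy far below the roughly 70% one might expect from missing 3 of 10 classes. The plain-Python sketch below mimics that index assignment; which three folders are missing is a made-up assumption for illustration.

```python
def infer_class_indices(subdirs):
    """Mimic inferred labels: class index = position in sorted folder names."""
    return {name: index for index, name in enumerate(sorted(subdirs))}


# Hypothetical: suppose folders 3, 6 and 8 were missing from the training directory.
train_dirs = ["0", "1", "2", "4", "5", "7", "9"]
test_dirs = [str(d) for d in range(10)]

train_idx = infer_class_indices(train_dirs)
test_idx = infer_class_indices(test_dirs)

# Digits whose integer label differs between the two splits.
mismatched = [d for d in train_dirs if train_idx[d] != test_idx[d]]
print(mismatched)  # ['4', '5', '7', '9']
```

In this made-up case, digit 9 is label 6 during training but label 9 at test time, and so on for the other shifted folders.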

%%bash

MNIST_PNG="mnist_png.tar.gz"
if ! [ -e "${MNIST_PNG}" ]; then
  curl -sO "https://raw.githubusercontent.com/myleott/mnist_png/master/${MNIST_PNG}"
fi

MNIST_DIR="mnist_png"
if ! [ -d "${MNIST_DIR}" ]; then
  tar zxf "${MNIST_PNG}"
fi

import tensorflow as tf

train_ds = tf.keras.utils.image_dataset_from_directory(
  'mnist_png/training/',
  validation_split = 0.2,
  subset = "training",
  seed = 123,
  image_size = (32, 32),
  batch_size = 100)

val_ds = tf.keras.utils.image_dataset_from_directory(
  'mnist_png/training/',
  validation_split = 0.2,
  subset = "validation",
  seed = 123,
  image_size = (32, 32),
  batch_size = 100)

test_ds = tf.keras.utils.image_dataset_from_directory(
  'mnist_png/testing/',
  seed = 123,
  image_size = (32, 32),
  batch_size = 1000)

Output:

Found 60000 files belonging to 10 classes.
Using 48000 files for training.
Found 60000 files belonging to 10 classes.
Using 12000 files for validation.
Found 10000 files belonging to 10 classes.

So you should fix this first and make sure you have the complete dataset; after that, you will likely get much better results on the test dataset.
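Before retraining, a quick way to confirm that the extracted dataset is complete is to count the PNG files in each class folder of both splits. The stdlib-only sketch below does that; the helper names and the expectation of ten folders named `0` to `9` are assumptions based on the `mnist_png` layout used above.

```python
from pathlib import Path


def count_pngs_per_class(split_dir):
    """Return {class_folder_name: number_of_png_files} for one split directory."""
    root = Path(split_dir)
    return {
        sub.name: sum(1 for _ in sub.glob("*.png"))
        for sub in sorted(root.iterdir())
        if sub.is_dir()
    }


def missing_classes(split_dir, expected=tuple(str(d) for d in range(10))):
    """Expected class folders that are absent or contain no PNG files."""
    counts = count_pngs_per_class(split_dir)
    return [name for name in expected if counts.get(name, 0) == 0]


if __name__ == "__main__":
    for split in ("mnist_png/training", "mnist_png/testing"):
        if not Path(split).is_dir():
            print(split, "not found")
            continue
        counts = count_pngs_per_class(split)
        print(split, sum(counts.values()), "files; missing classes:", missing_classes(split))
```

On a complete extraction, both splits should report no missing classes, with 60000 training and 10000 test files in total, matching the output above.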

Grayscale vs. RGB images

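I would also recommend specifying color_mode='grayscale' when calling image_dataset_from_directory(), because the dataset itself is grayscale (you can verify this with the PIL or matplotlib libraries). By default, image_dataset_from_directory() converts every image to RGB (3 channels), which is how you ended up with a 3-channel input: it simply duplicates the single grayscale channel 3 times.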
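To see why the extra channels carry no information, the NumPy sketch below mimics the RGB conversion on a random 28×28 grayscale array (a stand-in, not a real MNIST image, and not Keras's actual loading code): after the single channel is repeated three times, all three channels are identical, and keeping any one of them recovers the original exactly. Loading with `color_mode='grayscale'` avoids the duplication in the first place and yields `(32, 32, 1)` inputs, so the model's input shape would need to change to match.

```python
import numpy as np

rng = np.random.default_rng(0)

# A stand-in grayscale image: height x width x 1 channel, pixel values 0-255.
gray = rng.integers(0, 256, size=(28, 28, 1), dtype=np.uint8)

# Expanding to RGB just repeats the single channel three times.
rgb = np.repeat(gray, 3, axis=-1)
print(gray.shape, rgb.shape)  # (28, 28, 1) (28, 28, 3)

# All three channels are identical, so they add no information...
channels_identical = bool(np.array_equal(rgb[..., 0], rgb[..., 1])
                          and np.array_equal(rgb[..., 1], rgb[..., 2]))
# ...and any single channel recovers the original grayscale image.
recovered = rgb[..., :1]
print(channels_identical, np.array_equal(recovered, gray))  # True True
```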
