Keras/TensorFlow: model does not learn anything, constant loss, unstable accuracy, validation accuracy exactly 0

scyqe7ek, posted on 2023-06-30 in Other

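I am trying to train a small network to detect JPEG artifacts. I have 7 classes, each containing the same 1k images, differing only in the level of artifacts. My code looks like this: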

import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

import pathlib

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

data_dir = "C:/Users/jilek/Downloads/AAT"
data_dir = pathlib.Path(data_dir).with_suffix('')

image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)

batch_size = 1
img_height = 1024
img_width = 1024

train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    color_mode="grayscale",
    shuffle=True,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    color_mode="grayscale",
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

AUTOTUNE = tf.data.AUTOTUNE

data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal",
                          input_shape=(img_height,
                                       img_width,
                                       1)),
        layers.RandomRotation(0.5),
        layers.RandomZoom(0.2),
    ]
)

train_ds = train_ds.shuffle(buffer_size=1000).prefetch(buffer_size=AUTOTUNE) #.cache()
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE) #.cache()

num_classes = len(class_names)

model = Sequential([
  layers.Rescaling(1.0/255, input_shape=(img_height, img_width, 1)),
  layers.Conv2D(2, (4, 4), strides=(4, 4), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(1024, 1024, 1), activation='relu'),
  layers.Conv2D(4, (4, 4), strides=(4, 4), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(256, 256, 2), activation='relu'),
  layers.Conv2D(8, (4, 4), strides=(4, 4), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(64, 64, 4), activation='relu'),
  layers.Conv2D(16, (4, 4), strides=(4, 4), padding='valid', dilation_rate=(1, 1), groups=1, input_shape=(16, 16, 8), activation='relu'),
  layers.Flatten(),
  layers.Dense(64, activation='relu'),
  layers.Dense(16, activation='relu'),
  layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.summary()
model.save("./model/AAT")

epochs = 100
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

test_dir = "C:/Users/jilek/Downloads/AAT_T/"
for file_name in os.listdir(test_dir):
    file_path = os.path.join(test_dir, file_name)
    img = tf.keras.utils.load_img(
        file_path, target_size=(img_height, img_width), color_mode="grayscale"
    )
    img_array = tf.keras.utils.img_to_array(img)
    img_array = tf.expand_dims(img_array, 0)  # Create a batch

    predictions = model.predict(img_array)
    score = tf.nn.softmax(predictions[0])

    print(file_name)
    print(
        "This image most likely belongs to {} with a {:.2f} percent confidence."
        .format(class_names[np.argmax(score)], 100 * np.max(score))
    )

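I have tried changing the layers, the optimizer, and even the loss function, but every time I get either exactly the same result or an error. The log looks like this: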

5768/5768 [==============================] - 40s 6ms/step - loss: 1.9472 - accuracy: 0.1954 - val_loss: 1.9467 - val_accuracy: 0.0000e+00
Epoch 2/100
5768/5768 [==============================] - 34s 6ms/step - loss: 1.9468 - accuracy: 0.0381 - val_loss: 1.9473 - val_accuracy: 0.0000e+00
Epoch 3/100
5768/5768 [==============================] - 35s 6ms/step - loss: 1.9468 - accuracy: 0.3358 - val_loss: 1.9469 - val_accuracy: 0.0000e+00
Epoch 4/100
5768/5768 [==============================] - 32s 5ms/step - loss: 1.9468 - accuracy: 0.1555 - val_loss: 1.9468 - val_accuracy: 0.0000e+00
Epoch 5/100
5768/5768 [==============================] - 32s 5ms/step - loss: 1.9468 - accuracy: 0.1130 - val_loss: 1.9466 - val_accuracy: 0.0000e+00
Epoch 6/100
5768/5768 [==============================] - 34s 6ms/step - loss: 1.9467 - accuracy: 0.2715 - val_loss: 1.9468 - val_accuracy: 0.0000e+00
Epoch 7/100
5768/5768 [==============================] - 31s 5ms/step - loss: 1.9469 - accuracy: 0.2384 - val_loss: 1.9471 - val_accuracy: 0.0000e+00
Epoch 8/100
5768/5768 [==============================] - 32s 5ms/step - loss: 1.9469 - accuracy: 0.0836 - val_loss: 1.9465 - val_accuracy: 0.0000e+00
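
As a sanity check on these numbers: a loss that stays near 1.9468 is very close to what a 7-class classifier scores when it assigns the same probability to every class, which suggests the network is emitting an almost uniform distribution rather than learning. The short sketch below only reproduces that arithmetic; `NUM_CLASSES` is taken from the 7 classes described in the question, and nothing here touches the actual model.

```python
import math

NUM_CLASSES = 7  # the question states there are 7 classes

# Cross-entropy of a prediction that puts probability 1/7 on every class:
# -log(1/7) = log(7), regardless of which class is the true one.
uniform_loss = -math.log(1.0 / NUM_CLASSES)

print(f"loss for a uniform 7-class prediction: {uniform_loss:.4f}")  # 1.9459
```

The logged values (1.9465 to 1.9473) differ from this by less than 0.002 in every epoch.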

After training finishes, here is a test on images that were not part of the dataset:

831180_005.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 17ms/step
831180_010.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 19ms/step
831180_020.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 17ms/step
831180_040.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 21ms/step
831180_060.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 19ms/step
831180_080.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 17ms/step
831180_100.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
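
One detail of the prediction loop is worth noting when reading these confidences: the script passes `predictions[0]` through `tf.nn.softmax`, although the last layer of the model already uses `activation='softmax'`. Applying softmax to values that are already probabilities compresses them toward 1/7. The sketch below illustrates the effect with a hand-written `softmax` helper and two hypothetical output vectors (neither comes from the actual model).

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


NUM_CLASSES = 7

# Hypothetical model outputs that are ALREADY probabilities (they sum to 1),
# as produced by a final Dense(..., activation='softmax') layer.
confident = [1.0] + [0.0] * (NUM_CLASSES - 1)
uniform = [1.0 / NUM_CLASSES] * NUM_CLASSES

# Applying softmax a second time squashes the values toward 1/7.
confident_twice = softmax(confident)
uniform_twice = softmax(uniform)

print(f"confident input, max after second softmax: {max(confident_twice):.4f}")  # 0.3118
print(f"uniform input, max after second softmax:   {max(uniform_twice):.4f}")  # 0.1429
```

With 7 classes, the doubled softmax can never report more than about 31.2 percent, and the 14.34 percent printed for every file sits just above 1/7 (about 14.29 percent), which is consistent with a nearly uniform model output. Reading the probabilities directly from `predictions[0]` would show the model's actual confidence; it would not by itself fix the training problem.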

7vux5j2d #1

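I believe the loss function you are using is not suited to your problem. As far as I can tell, your output is one-hot encoded rather than sparse. I believe a plain "CrossEntropyLoss" should work.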
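
Whether the loss matches the data depends on the label format, so it may be worth checking before switching. In Keras, `sparse_categorical_crossentropy` expects integer class indices and `categorical_crossentropy` expects one-hot vectors; `tf.keras.utils.image_dataset_from_directory` yields integer labels unless `label_mode="categorical"` is passed. The question's code already prints `labels_batch.shape`: with `batch_size = 1`, a shape of `(1,)` indicates integer labels and `(1, 7)` indicates one-hot labels. ("CrossEntropyLoss" is the PyTorch name; the closest Keras counterpart for one-hot labels would be `categorical_crossentropy`.) The sketch below uses two hand-written stand-ins, `sparse_ce` and `onehot_ce`, and a hypothetical probability vector to show that the two formulations give the same value once labels and loss agree.

```python
import math


def sparse_ce(int_label, probs):
    """Cross-entropy for an integer class index, e.g. label 2."""
    return -math.log(probs[int_label])


def onehot_ce(onehot_label, probs):
    """Cross-entropy for a one-hot vector, e.g. [0, 0, 1, 0, 0, 0, 0]."""
    return -sum(t * math.log(p) for t, p in zip(onehot_label, probs))


# Hypothetical softmax output for one image over 7 classes (sums to 1).
probs = [0.05, 0.05, 0.60, 0.10, 0.10, 0.05, 0.05]

int_label = 2
onehot_label = [1.0 if i == int_label else 0.0 for i in range(len(probs))]

sparse_loss = sparse_ce(int_label, probs)
onehot_loss = onehot_ce(onehot_label, probs)

print(f"integer-label loss: {sparse_loss:.4f}")
print(f"one-hot-label loss: {onehot_loss:.4f}")
```

Both calls compute the negative log of the probability assigned to the true class, so the choice between them is about matching the label format rather than changing what is optimized.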
