我正在尝试训练一个小型网络来检测 JPEG 压缩伪影。我有 7 个类别,每个类别包含约 1000 张相同的图像,区别只在于伪影的程度不同。我的代码如下:
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import pathlib
from tensorflow.python.client import device_lib
# List the devices TensorFlow can see (useful to confirm a GPU is picked up).
# Uses the public tf.config API rather than the private
# tensorflow.python.client.device_lib module.
print(tf.config.list_physical_devices())

# Root folder containing one sub-folder per class.
# The former `.with_suffix('')` call (a leftover from the TF tutorial, where the
# path pointed at a downloaded archive) is dropped: it does nothing for "AAT"
# but would silently cut off the tail of any directory name containing a dot.
data_dir = pathlib.Path("C:/Users/jilek/Downloads/AAT")

# NOTE(review): only lower-case "*.jpg" files one level down are counted here;
# files with other extensions (e.g. ".jpeg", ".JPG") are not — confirm this
# count matches what the dataset loader reports.
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)

# NOTE(review): batch_size=1 gives very noisy gradient estimates; consider a
# larger batch if memory allows.
batch_size = 1
# Every image is resized to img_height x img_width when loaded.
# NOTE(review): if the source images are not already 1024x1024, that resize
# interpolates pixels and can blur the 8x8 JPEG block structure the model is
# supposed to detect — confirm the source resolution.
img_height = 1024
img_width = 1024
# Arguments shared by the training and validation loaders. Both calls must use
# the same seed and validation_split so the two subsets partition the same
# shuffled file list without overlap.
_dataset_kwargs = dict(
    validation_split=0.2,
    seed=123,
    color_mode="grayscale",
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="training", shuffle=True, **_dataset_kwargs
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="validation", **_dataset_kwargs
)

# Class names are taken from the sub-folder names of data_dir.
class_names = train_ds.class_names
print(class_names)

# Sanity check: print the shapes of a single (images, labels) batch.
# NOTE(review): with the loader's default label_mode the labels should be
# integer class ids of shape (batch_size,) — confirm from the printed shape.
image_batch, labels_batch = next(iter(train_ds))
print(image_batch.shape)
print(labels_batch.shape)
# Let tf.data choose the prefetch depth at runtime.
AUTOTUNE = tf.data.AUTOTUNE

# Random augmentation stack.
# NOTE(review): `data_augmentation` is built but never used in this script (it
# is neither mapped over train_ds nor placed inside the model), so it has no
# effect on training. If it is wired in later, note that RandomRotation and
# RandomZoom resample the image; that interpolation is likely to smear the JPEG
# artifacts the classifier is meant to detect — verify before enabling.
_augmentation_layers = [
    layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 1)),
    layers.RandomRotation(0.5),  # factor is a fraction of 2*pi, so up to +/-180 degrees
    layers.RandomZoom(0.2),
]
data_augmentation = keras.Sequential(_augmentation_layers)

# Extra shuffle on top of the shuffle=True already requested from the loader,
# then prefetch so input loading overlaps with training. .cache() stays
# disabled (presumably because the decoded images don't fit in RAM — TODO confirm).
# NOTE(review): with batch_size=1 every buffered element is one full 1024x1024
# image, so this 1000-element shuffle buffer keeps ~1000 decoded images in memory.
train_ds = train_ds.shuffle(buffer_size=1000)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
num_classes = len(class_names)

# Small all-convolutional classifier. Every Conv2D uses a 4x4 kernel with
# stride 4 and 'valid' padding, so with 1024x1024 input the spatial size goes
# 1024 -> 256 -> 64 -> 16 -> 4 before flattening (4 * 4 * 16 = 256 features).
# Only the first layer needs input_shape: Keras ignores it on later layers, so
# those arguments (and the default-valued dilation_rate / groups) are omitted.
#
# NOTE(review): the training log shows the loss stuck at ~1.947, which is about
# ln(7), i.e. the network outputs a near-uniform distribution over the 7 classes.
# Things worth checking: the 2-filter first layer can easily end up with dead
# ReLUs, and with a 4x4/stride-4 stem no first-layer window ever straddles an
# 8x8 JPEG block boundary (assuming the block grid starts at pixel 0). Smaller
# strides, more filters and/or BatchNormalization may help — not changed here.
model = Sequential([
    layers.Rescaling(1.0 / 255, input_shape=(img_height, img_width, 1)),
    layers.Conv2D(2, (4, 4), strides=(4, 4), padding='valid', activation='relu'),
    layers.Conv2D(4, (4, 4), strides=(4, 4), padding='valid', activation='relu'),
    layers.Conv2D(8, (4, 4), strides=(4, 4), padding='valid', activation='relu'),
    layers.Conv2D(16, (4, 4), strides=(4, 4), padding='valid', activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(16, activation='relu'),
    # Softmax output: the model returns class probabilities, not logits.
    layers.Dense(num_classes, activation='softmax'),
])

# With the loader's default integer labels, sparse_categorical_crossentropy is
# the matching loss (one-hot labels would need categorical_crossentropy).
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

epochs = 100
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
)

# Save only AFTER training. Previously the model was saved right after
# compile(), so the file on disk only ever held the random initial weights.
model.save("./model/AAT")
# Per-epoch metrics recorded by model.fit().
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Derive the x-axis from the recorded history rather than the requested
# `epochs`. If fit() ever stops early (e.g. via an EarlyStopping callback), the
# metric lists are shorter than `epochs`, and plotting against range(epochs)
# would raise a length-mismatch error.
epochs_range = range(len(acc))

plt.figure(figsize=(8, 8))

# Left panel: accuracy curves.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Right panel: loss curves.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
test_dir = "C:/Users/jilek/Downloads/AAT_T/"

# Classify every file in test_dir and print the most likely class.
# Sorted so the output order is deterministic.
for file_name in sorted(os.listdir(test_dir)):
    file_path = os.path.join(test_dir, file_name)
    # os.listdir also returns sub-directories; skip anything that isn't a file.
    if not os.path.isfile(file_path):
        continue

    # Preprocess the same way the training loader does (tf.io.decode_image to
    # 1 channel, then tf.image.resize, whose default method is bilinear like the
    # loader's default interpolation). The previous keras load_img path decodes
    # via PIL and resizes with 'nearest' by default, so test images were
    # preprocessed differently from the training images.
    raw = tf.io.read_file(file_path)
    img = tf.io.decode_image(raw, channels=1, expand_animations=False)
    img = tf.image.resize(img, (img_height, img_width), method="bilinear")
    img_array = tf.expand_dims(img, 0)  # Create a batch of one

    predictions = model.predict(img_array)
    # The model's last layer already applies softmax, so predictions[0] is the
    # probability vector. The previous extra tf.nn.softmax squashed every output
    # toward uniform (1/7 ≈ 14.3 %), hiding the model's real confidence.
    # NOTE(review): identical scores for every image also reflect the collapsed
    # model (training loss ≈ ln 7); fixing this alone won't make predictions good.
    score = predictions[0]
    print(file_name)
    print(
        "This image most likely belongs to {} with a {:.2f} percent confidence."
        .format(class_names[np.argmax(score)], 100 * np.max(score))
    )
我尝试过调整网络层、优化器,甚至更换损失函数,但每次得到的都是完全相同的结果,或者直接报错。训练日志如下:
5768/5768 [==============================] - 40s 6ms/step - loss: 1.9472 - accuracy: 0.1954 - val_loss: 1.9467 - val_accuracy: 0.0000e+00
Epoch 2/100
5768/5768 [==============================] - 34s 6ms/step - loss: 1.9468 - accuracy: 0.0381 - val_loss: 1.9473 - val_accuracy: 0.0000e+00
Epoch 3/100
5768/5768 [==============================] - 35s 6ms/step - loss: 1.9468 - accuracy: 0.3358 - val_loss: 1.9469 - val_accuracy: 0.0000e+00
Epoch 4/100
5768/5768 [==============================] - 32s 5ms/step - loss: 1.9468 - accuracy: 0.1555 - val_loss: 1.9468 - val_accuracy: 0.0000e+00
Epoch 5/100
5768/5768 [==============================] - 32s 5ms/step - loss: 1.9468 - accuracy: 0.1130 - val_loss: 1.9466 - val_accuracy: 0.0000e+00
Epoch 6/100
5768/5768 [==============================] - 34s 6ms/step - loss: 1.9467 - accuracy: 0.2715 - val_loss: 1.9468 - val_accuracy: 0.0000e+00
Epoch 7/100
5768/5768 [==============================] - 31s 5ms/step - loss: 1.9469 - accuracy: 0.2384 - val_loss: 1.9471 - val_accuracy: 0.0000e+00
Epoch 8/100
5768/5768 [==============================] - 32s 5ms/step - loss: 1.9469 - accuracy: 0.0836 - val_loss: 1.9465 - val_accuracy: 0.0000e+00
训练完成后,我用不在数据集中的图像进行了测试,结果如下:
831180_005.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 17ms/step
831180_010.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 19ms/step
831180_020.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 17ms/step
831180_040.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 21ms/step
831180_060.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 19ms/step
831180_080.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1/1 [==============================] - 0s 17ms/step
831180_100.jpg
This image most likely belongs to C060 with a 14.34 percent confidence.
1条答案
按热度按时间7vux5j2d1#
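我认为你使用的损失函数不适合你的问题。据我所知,你的输出是独热编码(one-hot)格式,而不是稀疏(整数)格式,所以我认为使用普通的交叉熵损失“CrossEntropyLoss”应该就可以了。

(补充说明:`image_dataset_from_directory` 默认 `label_mode="int"`,生成的标签是整数类别编号而不是独热编码,因此 `sparse_categorical_crossentropy` 与之是匹配的。在 Keras 中,与独热标签对应的损失是 `categorical_crossentropy`;`CrossEntropyLoss` 是 PyTorch 中的名称。)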