Keras VAE reconstructed images are very blurry

tf7tbtn2, posted 2023-10-19 in Other

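I am very new to machine learning and built a VAE from the Keras VAE code example, changing only a few layers of the model. I trained the model on the Kaggle cats and dogs dataset and then tried to reconstruct some images. All the reconstructed images look the same, like these: [image: Reconstructed Images]. What could be causing this? Is it a bad model, too little training, or did I make a mistake when reconstructing the images?
Encoder model: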

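import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class Sampling(layers.Layer):
    """Reparameterization trick, as in the Keras VAE example: z = z_mean + exp(0.5 * z_log_var) * epsilon."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


latent_dim = 2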
encoder_inputs = keras.Input(shape=(328, 328, 3))
x = layers.Conv2D(32, 3, strides=2, padding="same")(encoder_inputs)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, 3, strides=2, padding="same")(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(128, 3, strides=2, padding="same")(x)  # new
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
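A quick way to sanity-check the encoder is to do its shape arithmetic by hand. The plain-Python sketch below (no Keras needed) assumes the standard rule for `padding="same"`: a convolution with stride `s` turns an `n`-pixel side into `ceil(n / s)` pixels. It also counts the weights of the `Dense(16)` layer after `Flatten`, using the usual `inputs * outputs + biases` formula; the helper names are hypothetical.

```python
import math


def conv_same_out(n, stride):
    """Spatial size after a Conv2D with padding="same" (assumed rule: ceil(n / stride))."""
    return math.ceil(n / stride)


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected (Dense) layer."""
    return n_in * n_out + n_out


size = 328
for _ in range(3):  # the encoder's three Conv2D layers, each with strides=2
    size = conv_same_out(size, 2)

flat = size * size * 128  # the last Conv2D has 128 filters
print(size)                    # 41
print(flat)                    # 215168
print(dense_params(flat, 16))  # 3442704
```

So about 3.4 million weights compress each 328×328×3 photo into 16 numbers, which are then reduced to `latent_dim = 2` values for `z_mean` and `z_log_var`. Two latent dimensions is the setting the original Keras example uses for MNIST digits.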

Decoder model:

latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(41 * 41 * 128, activation="relu")(latent_inputs)
x = layers.Reshape((41, 41, 128))(x)
x = layers.Conv2DTranspose(128, 3, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
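The decoder has to undo the encoder's downsampling exactly so that its output matches the 328×328×3 input. With `padding="same"`, a `Conv2DTranspose` with stride `s` turns an `n`-pixel side into `n * s` pixels. The plain-Python sketch below (hypothetical helper names, no Keras needed) walks the decoder's four transposed convolutions under that rule and counts the weights of the first `Dense` layer, which expands the 2-dimensional latent vector to 41·41·128 values.

```python
def conv_transpose_same_out(n, stride):
    """Spatial size after a Conv2DTranspose with padding="same": n * stride."""
    return n * stride


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected (Dense) layer."""
    return n_in * n_out + n_out


latent_dim = 2
size = 41  # from layers.Reshape((41, 41, 128))
for stride in (2, 2, 2, 1):  # three upsampling layers, then the stride-1 output layer
    size = conv_transpose_same_out(size, stride)

print(size)                                      # 328, matching the encoder input
print(dense_params(latent_dim, 41 * 41 * 128))  # 645504
```

The shapes line up, so a size mismatch is not what makes the reconstructions identical.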

Training:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_data_dir = '/content/PetImages'
nb_train_samples = 200
nb_epoch = 50
batch_size = 32
img_width = 328
img_height = 328

# Yield (input, target) pairs; for an autoencoder the target is the input itself
def fixed_generator(generator):
    for batch in generator:
        yield (batch, batch)

train_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode=None)

vae = VAE(encoder, decoder)  # VAE: the model class from the Keras VAE code example
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(
        fixed_generator(train_generator),
        steps_per_epoch=nb_train_samples,  # counts batches of batch_size images per epoch, not samples
        epochs=nb_epoch,
        )
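`VAE` here is the model class from the Keras VAE code example. As far as I can tell from that example, its `train_step` minimizes the sum of a reconstruction term (binary cross-entropy summed over the pixel grid) and a KL-divergence term that pulls each image's `z_mean`/`z_log_var` toward a standard normal prior. The plain-Python sketch below recomputes both terms for a single flattened toy image so the two parts are visible. The function names are hypothetical, and the example's extra step of averaging the cross-entropy over the three colour channels (which only rescales that term) is left out.

```python
import math


def bce_sum(targets, preds, eps=1e-7):
    """Binary cross-entropy summed over all values of one flattened image."""
    total = 0.0
    for t, p in zip(targets, preds):
        p = min(max(p, eps), 1 - eps)  # clip so log() never sees 0
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total


def kl_to_standard_normal(z_mean, z_log_var):
    """Closed-form KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over latent dims."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(z_mean, z_log_var))


def vae_loss(targets, preds, z_mean, z_log_var):
    """Reconstruction term plus KL term; assumed to mirror the Keras example, for one image."""
    return bce_sum(targets, preds) + kl_to_standard_normal(z_mean, z_log_var)


image = [0.0, 0.2, 0.8, 1.0]  # toy 4-value "image"
recon = [0.1, 0.3, 0.7, 0.9]  # toy reconstruction
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0: the code matches the prior exactly
print(vae_loss(image, recon, [0.5, -0.5], [0.0, 0.0]))
```

The KL term is smallest when every image is encoded to the prior. If it is strong relative to the reconstruction term, the encoder can drift toward nearly identical codes for all inputs, and the decoder then produces nearly identical outputs.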

Reconstructing the images:

import matplotlib.pyplot as plt

test2_datagen = ImageDataGenerator(rescale=1./255)

test2_generator = test2_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=10,
        class_mode=None)

sample_img = next(test2_generator)

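# The encoder returns [z_mean, z_log_var, z], but the decoder takes a single latent tensor,
# so pass only z_mean (a deterministic reconstruction) rather than the whole list
z_mean, z_log_var, z = vae.encoder.predict(sample_img)

reconst_images = vae.decoder.predict(z_mean)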

fig = plt.figure(figsize=(10, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)

n_to_show = 2

for i in range(n_to_show):
    img = sample_img[i].squeeze()
    sub = fig.add_subplot(2, n_to_show, i+1)
    sub.axis('off')        
    sub.imshow(img)

for i in range(n_to_show):
    img = reconst_images[i].squeeze()
    sub = fig.add_subplot(2, n_to_show, i+n_to_show+1)
    sub.axis('off')
    sub.imshow(img)

plt.show()
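To put a number on "all the reconstructions look the same", one option is to compare how much each pixel varies across a batch of inputs with how much it varies across the corresponding reconstructions. The helper below is hypothetical and written in plain Python on flattened images so it runs without Keras; with the arrays above you would first flatten each image, e.g. `sample_img[i].ravel().tolist()` and `reconst_images[i].ravel().tolist()`.

```python
import math


def mean_pixel_std(images):
    """Average over pixel positions of the standard deviation across images.

    `images` is a list of equally long lists of floats (flattened images).
    A value near 0 means every image in the batch is (nearly) identical.
    """
    n = len(images)
    n_pixels = len(images[0])
    total = 0.0
    for j in range(n_pixels):
        column = [img[j] for img in images]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / n  # population variance
        total += math.sqrt(var)
    return total / n_pixels


originals = [[0.0, 1.0, 0.5], [1.0, 0.0, 0.5]]
reconstructions = [[0.48, 0.52, 0.5], [0.5, 0.5, 0.5]]
print(mean_pixel_std(originals))
print(mean_pixel_std(reconstructions))
```

If the value for the reconstructions is far below the value for the originals, the decoder is largely ignoring its input, which points at the model or its training rather than at the plotting code.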
Answer 1 (g52tjvyc):

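Try using a pretrained classification network as the feature extractor (the encoder) instead of training it from scratch, and train only the decoder half.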
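The suggestion above can be sketched as follows, assuming TensorFlow 2.6 or later (for `layers.Rescaling`). VGG16 from `keras.applications` stands in as the frozen, ImageNet-pretrained encoder purely as an example, and the images are assumed to be resized to 320×320 so that VGG16's five 2×2 poolings (320 → 10) can be undone by five stride-2 `Conv2DTranspose` layers. The function name and the decoder's filter counts are hypothetical choices. Note that this is a plain autoencoder, not a VAE: there is no sampling layer and no KL term.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_frozen_encoder_autoencoder(img_size=320, weights="imagenet"):
    """Hypothetical helper: autoencoder with a frozen VGG16 encoder; only the decoder trains.

    Assumes inputs in [0, 1] (as from ImageDataGenerator(rescale=1./255))
    with shape (img_size, img_size, 3), where img_size is divisible by 32.
    """
    encoder = keras.applications.VGG16(
        include_top=False, weights=weights, input_shape=(img_size, img_size, 3)
    )
    encoder.trainable = False  # freeze all pretrained weights

    inputs = keras.Input(shape=(img_size, img_size, 3))
    x = layers.Rescaling(255.0)(inputs)  # back to [0, 255] as VGG16 preprocessing expects
    x = keras.applications.vgg16.preprocess_input(x)
    x = encoder(x, training=False)  # shape (img_size / 32, img_size / 32, 512)

    # Trainable decoder: five stride-2 upsamplings undo VGG16's five poolings.
    for filters in (256, 128, 64, 32, 16):
        x = layers.Conv2DTranspose(
            filters, 3, strides=2, padding="same", activation="relu"
        )(x)
    outputs = layers.Conv2D(3, 3, padding="same", activation="sigmoid")(x)

    model = keras.Model(inputs, outputs, name="frozen_encoder_ae")
    model.compile(optimizer=keras.optimizers.Adam(), loss="binary_crossentropy")
    return model
```

Because the model's target is its own input, it can be trained on the `(batch, batch)` pairs that `fixed_generator` already yields, e.g. `model.fit(fixed_generator(train_generator), steps_per_epoch=nb_train_samples, epochs=nb_epoch)` with `target_size=(320, 320)` in `flow_from_directory`.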
