I am very new to machine learning and built a VAE from the Keras VAE code example, changing only a few layers in the model. I trained it on the Kaggle cats-and-dogs dataset and then tried to reconstruct some images. All of the reconstructions look the same, like these Reconstructed Images. What could be causing this? Is the model itself bad, is the training too short, or did I make a mistake when reconstructing the images?
Encoder model:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 2

encoder_inputs = keras.Input(shape=(328, 328, 3))
x = layers.Conv2D(32, 3, strides=2, padding="same")(encoder_inputs)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, 3, strides=2, padding="same")(x)
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(128, 3, strides=2, padding="same")(x)  # newly added layer
x = layers.Activation("relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
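The Sampling layer referenced above is not shown in the post; a minimal sketch, assuming the standard reparameterization layer from the Keras VAE example:

# Reparameterization trick: draw z from N(z_mean, exp(z_log_var)) in a
# differentiable way, as in the Keras VAE example.
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon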
Decoder model:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(41 * 41 * 128, activation="relu")(latent_inputs)
x = layers.Reshape((41, 41, 128))(x)
x = layers.Conv2DTranspose(128, 3, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
x = layers.BatchNormalization()(x)
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
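The VAE class used for training below is also not shown in the post. A sketch of what it looks like in the Keras example (a keras.Model subclass whose custom train_step sums the reconstruction and KL losses; the loss tracking here is simplified):

class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.total_loss_tracker]

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]  # the generator below yields (batch, batch)
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Per-image reconstruction loss summed over pixels, averaged over the batch
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            # KL divergence between the approximate posterior and a unit Gaussian
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        return {"loss": self.total_loss_tracker.result()}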
Training:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_data_dir = '/content/PetImages'
nb_train_samples = 200
nb_epoch = 50
batch_size = 32
img_width = 328
img_height = 328

# Yield each image batch as both input and target, so the VAE is trained
# to reconstruct its own inputs.
def fixed_generator(generator):
    for batch in generator:
        yield (batch, batch)

train_datagen = ImageDataGenerator(
    rescale=1./255,
)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode=None)

vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(
    fixed_generator(train_generator),
    steps_per_epoch=nb_train_samples,
    epochs=nb_epoch,
)
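One thing worth double-checking (not part of the original post): Keras counts steps_per_epoch in batches, not in images, so with 200 images and a batch size of 32, one pass over the data is only a handful of steps:

# steps_per_epoch is the number of batches drawn per epoch; one full pass
# over nb_train_samples images would be roughly:
steps_per_epoch = nb_train_samples // batch_size  # 200 // 32 = 6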
Reconstructing the images:
import matplotlib.pyplot as plt

test2_datagen = ImageDataGenerator(rescale=1./255)
test2_generator = test2_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=10,
    class_mode=None)

sample_img = next(test2_generator)
z_points = vae.encoder.predict(sample_img)
reconst_images = vae.decoder.predict(z_points)

fig = plt.figure(figsize=(10, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
n_to_show = 2

for i in range(n_to_show):
    img = sample_img[i].squeeze()
    sub = fig.add_subplot(2, n_to_show, i+1)
    sub.axis('off')
    sub.imshow(img)

for i in range(n_to_show):
    img = reconst_images[i].squeeze()
    sub = fig.add_subplot(2, n_to_show, i+n_to_show+1)
    sub.axis('off')
    sub.imshow(img)
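Note that the encoder was built with three outputs, so encoder.predict returns the list [z_mean, z_log_var, z], while the decoder expects a single latent batch. A hedged sketch of feeding it one of those outputs explicitly (this is not the original code; z_mean is used here because it gives a deterministic reconstruction):

# Unpack the encoder's three outputs and pass one latent batch to the decoder.
z_mean, z_log_var, z = vae.encoder.predict(sample_img)
reconst_images = vae.decoder.predict(z_mean)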
1 Answer
Try using a pretrained classification network as the feature extractor instead of training the encoder from scratch, and train only the decoder half.
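A minimal sketch of that idea, assuming a frozen VGG16 backbone from keras.applications as the feature extractor (the choice of VGG16 and the layer sizes are illustrative, not from the answer; ideally the inputs would also go through keras.applications.vgg16.preprocess_input rather than a plain 1/255 rescale):

from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 2

# Frozen pretrained feature extractor; its weights are never updated.
backbone = keras.applications.VGG16(
    include_top=False, weights="imagenet", input_shape=(328, 328, 3))
backbone.trainable = False

encoder_inputs = keras.Input(shape=(328, 328, 3))
x = backbone(encoder_inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

# Only the small dense layers above and the decoder receive gradients;
# the decoder itself can stay as defined in the question.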