我在CIFAR10数据集上训练了一个VAE。然而,当我试图从VAE生成图像时,我得到的只是一堆灰色噪声。此VAE的实现遵循Generative Deep Learning一书中的实现,但代码使用PyTorch而不是TensorFlow。
包含训练和生成的笔记本可以在here中找到,而VAE的实际实现可以在here中找到。
我试过:
1.令人失去能力的辍学者。
1.增加潜在空间的维度。
这些方法都没有显示出任何改进。
本人已证实:
1.输入大小与输出大小匹配
1.当训练过程中损失减少时,反向传播成功运行。
1条答案
按热度按时间r7knjye21#
感谢您提供代码和指向Colab笔记本的链接!+1!此外,您的代码编写得很好,易于阅读。除非我错过了什么,我认为你的代码有两个问题:
1.所述数据标准化
1.执行VAE损失。
关于1.,您的
CIFAR10DataModule
类使用mean = 0.5
和std = 0.5
规范化CIFAR 10图像的RGB通道。由于像素值最初在[0,1]范围内,因此归一化图像具有在[-1,1]范围内的像素值。但是,您的Decoder
类将nn.Sigmoid()
激活应用于重建图像。因此,重建图像的像素值在[0,1]范围内。我建议删除这个均值标准化,这样“真实”图像和重建图像的像素值都在[0,1]范围内。约2.:因为你处理的是RGB图像,所以MSE损失是有道理的。MSE损失背后的想法是“高斯解码器”。该解码器假设“真实图像”的像素值是由独立的高斯分布生成的,其平均值是重建图像的像素值(即,解码器的输出)并且具有给定的方差。您对重建损失(即
r_loss = F.mse_loss(predictions, targets)
)的实现相当于一个固定方差。利用this paper的思想,我们可以做得更好,并获得这个方差参数的“最优值”的解析表达式。最后,重建损失应该在所有像素上求和(reduction = 'sum'
)。为了理解为什么,看看重建损失的解析表达式(例如,参见考虑BCE损失的this blog post)。下面是重构后的
LitVAE
类:经过10个时期,这就是重建图像的样子: