PyTorch: What could cause a VAE (variational autoencoder) to output random noise even after training?

bihw5rsg · posted 2023-10-20 · in Other

I trained a VAE on the CIFAR10 dataset. However, when I try to generate images from the VAE, all I get is a bunch of gray noise. The implementation of this VAE follows the one in the book Generative Deep Learning, but the code uses PyTorch instead of TensorFlow.
The notebook containing training and generation can be found here, while the actual implementation of the VAE can be found here.
I have tried:
1. Disabling dropout.
2. Increasing the dimensionality of the latent space.
Neither of these showed any improvement.
I have verified that:
1. The input size matches the output size.
2. Backpropagation runs successfully, since the loss decreases during training.
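The backpropagation check in item 2 can be sketched as follows (using a small stand-in model rather than the actual VAE, purely for illustration):

```python
import torch
import torch.nn as nn

# Stand-in autoencoder-like model; the real check would use the VAE itself.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(4, 8)

# Reconstruction-style loss, then one backward pass.
loss = ((model(x) - x) ** 2).mean()
loss.backward()

# Every parameter should have received a non-None gradient.
assert all(p.grad is not None for p in model.parameters())
```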

r7knjye2

Thanks for providing the code and the link to the Colab notebook, +1! Also, your code is well written and easy to read. Unless I am missing something, I think there are two issues with your code:
1. The data normalization.
2. The implementation of the VAE loss.

Regarding 1.: your CIFAR10DataModule class normalizes the RGB channels of the CIFAR10 images with mean = 0.5 and std = 0.5. Since the pixel values are initially in the [0, 1] range, the normalized images have pixel values in the [-1, 1] range. However, your Decoder class applies an nn.Sigmoid() activation to the reconstructed images, so the reconstructed pixel values lie in the [0, 1] range. I suggest removing this mean normalization so that both the "true" images and the reconstructed images have pixel values in the [0, 1] range.
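The range mismatch described above can be illustrated with a quick standalone check (toy tensors, not your actual pipeline):

```python
import torch

torch.manual_seed(0)

# Pixel values in [0, 1], as produced by ToTensor().
x = torch.rand(4, 3, 32, 32)

# Normalize(mean=0.5, std=0.5) maps them to [-1, 1].
x_norm = (x - 0.5) / 0.5

# A Sigmoid decoder output is always in (0, 1).
sigmoid_out = torch.sigmoid(torch.randn(4, 3, 32, 32))

# The normalized targets can be negative, but the Sigmoid decoder can
# never produce a negative value, so the two ranges cannot match.
assert x_norm.min() < 0
assert sigmoid_out.min() > 0
```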
Regarding 2.: since you are dealing with RGB images, the MSE loss makes sense. The idea behind the MSE loss is the "Gaussian decoder": this decoder assumes that the pixel values of the "true image" are generated by independent Gaussian distributions whose means are the pixel values of the reconstructed image (i.e., the output of the decoder) and which share a given variance. Your implementation of the reconstruction loss (i.e., r_loss = F.mse_loss(predictions, targets)) is equivalent to using a fixed variance. Using the ideas from this paper, we can do better and obtain an analytic expression for the "optimal value" of this variance parameter. Finally, the reconstruction loss should be summed over all pixels (reduction = 'sum'). To understand why, take a look at the analytic expression of the reconstruction loss (see, e.g., this blog post, which considers the BCE loss).
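The analytic optimum mentioned above is sigma² = MSE, i.e. log sigma_opt = 0.5 · log(MSE). This can be verified numerically: the summed Gaussian negative log-likelihood (constants dropped) is no larger at the analytic optimum than at nearby values of log sigma (toy tensors below, for illustration only):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
targets = torch.rand(8, 3, 32, 32)
predictions = torch.rand(8, 3, 32, 32)

def gaussian_nll(log_sigma):
    # Per-pixel Gaussian NLL with a shared scalar sigma, summed over pixels.
    return (0.5 * ((targets - predictions) / log_sigma.exp()) ** 2 + log_sigma).sum()

# Analytic optimum: sigma^2 equals the mean squared error.
mse = F.mse_loss(predictions, targets, reduction='mean')
log_sigma_opt = 0.5 * mse.log()

# The NLL at the analytic optimum is a minimum over log sigma.
assert gaussian_nll(log_sigma_opt) <= gaussian_nll(log_sigma_opt + 0.1)
assert gaussian_nll(log_sigma_opt) <= gaussian_nll(log_sigma_opt - 0.1)
```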

Here is the refactored LitVAE class:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import optim

# VariationalAutoEncoder (and MODEL_PARAMS, used below) come from the
# asker's own implementation in the linked notebook.

class LitVAE(pl.LightningModule):
    def __init__(self,
                 learning_rate: float = 0.0005,
                 **kwargs) -> None:
        """
        Parameters
        ----------
        - `learning_rate: float`:
            learning rate for the optimizer
        - `**kwargs`:
            arguments to pass to the variational autoencoder constructor
        """
        super(LitVAE, self).__init__()
        
        self.learning_rate = learning_rate 

        self.vae = VariationalAutoEncoder(**kwargs)

    def forward(self, x):
        # returns (mean, log_variance, reconstructed_images)
        return self.vae(x)

    def training_step(self, batch, batch_idx):
        r_loss, kl_loss, sigma_opt = self.shared_step(batch)
        loss = r_loss + kl_loss
        
        self.log("train_loss_step", loss)
        return {"loss": loss, 'log':{"r_loss": r_loss / len(batch[0]), "kl_loss": kl_loss / len(batch[0]), 'sigma_opt': sigma_opt}}

    def training_epoch_end(self, outputs) -> None:
        # add computation graph
        if(self.current_epoch == 0):
            sample_input = torch.randn((1, 3, 32, 32))
            sample_model = LitVAE(**MODEL_PARAMS)
            
            self.logger.experiment.add_graph(sample_model, sample_input)
            
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("train_loss_epoch", epoch_loss, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        r_loss, kl_loss, _ = self.shared_step(batch)
        loss = r_loss + kl_loss

        self.log("valid_loss_step", loss)

        return {"loss": loss}

    def validation_epoch_end(self, outputs) -> None:
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("valid_loss_epoch", epoch_loss, self.current_epoch)

    def test_step(self, batch, batch_idx):
        r_loss, kl_loss, _ = self.shared_step(batch)
        loss = r_loss + kl_loss
        
        self.log("test_loss_step", loss)
        return {"loss": loss}

    def test_epoch_end(self, outputs) -> None:
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("test_loss_epoch", epoch_loss, self.current_epoch)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)
        
    def shared_step(self, batch):
        # images are both samples and targets thus original 
        # labels from the dataset are not required
        true_images, _ = batch

        # perform a forward pass through the VAE 
        # mean and log_variance are used to calculate the KL Divergence loss 
        # decoder_output represents the generated images 
        mean, log_variance, generated_images = self(true_images)

        r_loss, kl_loss, sigma_opt = self.calculate_loss(mean, log_variance, generated_images, true_images)
        return r_loss, kl_loss, sigma_opt

    def calculate_loss(self, mean, log_variance, predictions, targets):
        mse = F.mse_loss(predictions, targets, reduction='mean')
        log_sigma_opt = 0.5 * mse.log()
        r_loss = 0.5 * torch.pow((targets - predictions) / log_sigma_opt.exp(), 2) + log_sigma_opt
        r_loss = r_loss.sum()
        kl_loss = self._compute_kl_loss(mean, log_variance)
        return r_loss, kl_loss, log_sigma_opt.exp()

    def _compute_kl_loss(self, mean, log_variance): 
        return -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())

    def average_metric(self, metrics, metric_name):
        avg_metric = torch.stack([x[metric_name] for x in metrics]).mean()
        return avg_metric
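As a quick sanity check (not strictly necessary), the closed-form KL term in _compute_kl_loss can be verified against torch.distributions, which computes the KL divergence between N(mean, sigma) and the standard normal prior directly:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mean = torch.randn(5, 10)
log_variance = torch.randn(5, 10)

# Closed form used in _compute_kl_loss above.
closed_form = -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())

# Reference value: KL( N(mean, sigma) || N(0, 1) ), with sigma = exp(log_var / 2).
reference = kl_divergence(
    Normal(mean, (0.5 * log_variance).exp()),
    Normal(torch.zeros_like(mean), torch.ones_like(mean)),
).sum()

assert torch.allclose(closed_form, reference, atol=1e-5)
```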

After 10 epochs, this is what the reconstructed images look like:
