pytorch log_prob和手动计算之间的差异

inkz8wg9  于 2023-03-02  发布在  其他
关注(0)|答案(1)|浏览(168)

我想用平均值[1, 1, 1]和方差协方差矩阵0.3定义多元正态分布,然后计算数据点[2, 3, 4]的对数似然

按 Torch 分布

import torch
import torch.distributions as td

input_x = torch.tensor([2, 3, 4])
loc = torch.ones(3)
scale = torch.eye(3) * 0.3
mvn = td.MultivariateNormal(loc = loc, scale_tril=scale)
mvn.log_prob(input_x)
tensor(-76.9227)

从零开始

通过使用对数似然公式:

我们得到Tensor:

first_term = (2 * np.pi* 0.3)**(3)
first_term = -np.log(np.sqrt(first_term))
x_center = input_x - loc
tmp = torch.matmul(x_center, scale.inverse())
tmp = -1/2 * torch.matmul(tmp, x_center)
first_term + tmp 
tensor(-24.2842)

我用了

我的问题是--这种差异的来源是什么?

nle07wnf

nle07wnf1#

您正在将协方差矩阵传递给scale_tril,而不是covariance_matrix。来自PyTorch的多元正态分布文档
scale_tril(Tensor)-协方差的下三角因子,对角线为正值
因此,将scale_tril替换为covariance_matrix将产生与手动尝试相同的结果。

In [1]: mvn = td.MultivariateNormal(loc = loc, covariance_matrix=scale)
In [2]: mvn.log_prob(input_x)
Out[2]: tensor(-24.2842)

然而,根据作者的说法,使用scale_tril更有效:
...使用scale_tril会更高效:
可以使用torch.linalg.cholesky计算下切尔斯基

In [3]: mvn = td.MultivariateNormal(loc = loc, scale_tril=torch.linalg.cholesky(scale))
In [4]: mvn.log_prob(input_x)
Out[4]: tensor(-24.2842)

相关问题