如果我有n个大小的均值向量,n个大小的方差向量,那么我该怎么做呢?
z ∼ N (μ, σ)
import torch x = torch.randn(3, 3) mu = x.mean() sigma = x.var()
我该怎么做才能得到z?
lymnna711#
如果你想从平均值为mu和 stdsigma的正态分布中采样,那么你可以简单地
mu
sigma
z = torch.randn_like(mu) * sigma + mu
如果你对许多这样的z进行采样,它们的均值和标准差将收敛到sigma和mu:
z
mu = torch.arange(10.) Out[]: tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) sigma = 5. - 0.5 * torch.arange(10.) Out[]: tensor([5.0000, 4.5000, 4.0000, 3.5000, 3.0000, 2.5000, 2.0000, 1.5000, 1.0000, 0.5000]) z = torch.randn(10, 1000000) * sigma[:, None] + mu[:, None] z.mean(dim=1) Out[]: tensor([-5.4823e-03, 1.0011e+00, 1.9982e+00, 2.9985e+00, 4.0017e+00, 4.9972e+00, 6.0010e+00, 7.0004e+00, 7.9996e+00, 9.0006e+00]) z.std(dim=1) Out[]: tensor([4.9930, 4.4945, 4.0021, 3.5013, 3.0005, 2.4986, 1.9997, 1.4998, 0.9990, 0.5001])
正如您所看到的,当您从分布中采样1,000,000个元素时,样本均值和标准差接近您开始使用的原始mu和sigma。
ippsafx72#
上面描述的方法是正确的,但另一个好方法是使用torch.normal()。下面是从具有Tensor均值和方差的正态分布中抽样的示例。示例如下所示。
torch.normal()
import torch a = torch.tensor([1, 2, 3, 4 * 2, 2]) b = torch.tensor([5, 6, 7, 8 * 2, 2]) c = torch.tensor([5, 6, 7, 8 * 2, 2]) d = torch.tensor([9, 10, 11, 12 * 2, 2]) vectors = torch.stack([a, b, c, d]).float() mean_vector = vectors.mean(dim=0) std_vector = vectors.std(dim=0) sampled1 = torch.normal(mean=mean_vector, std=std_vector) # <-- sampled2 = torch.normal(mean=mean_vector, std=std_vector) print("mean_vector:", mean_vector) print("std_vector:", std_vector) print(sampled1) print(sampled2)
mean_vector:Tensor([ 5、6、7、16、2])std_vector:tensor([3.2660,3.2660,3.2660,6.5320,0.0000])Tensor([ 3.5221,-0.0818,5.1410,13.2573,2.0000])Tensor([ 6.2393,6.0634,2.4302,20.4233,2.0000])
2条答案
按热度按时间lymnna711#
如果你想从平均值为
mu
和 stdsigma
的正态分布中采样,那么你可以简单地如果你对许多这样的
z
进行采样,它们的均值和标准差将收敛到sigma
和mu
:正如您所看到的,当您从分布中采样1,000,000个元素时,样本均值和标准差接近您开始使用的原始
mu
和sigma
。ippsafx72#
上面描述的方法是正确的,但另一个好方法是使用
torch.normal()
。下面是从具有Tensor均值和方差的正态分布中抽样的示例。
示例如下所示。
PyTorch
结果
mean_vector:Tensor([ 5、6、7、16、2])
std_vector:tensor([3.2660,3.2660,3.2660,6.5320,0.0000])
Tensor([ 3.5221,-0.0818,5.1410,13.2573,2.0000])
Tensor([ 6.2393,6.0634,2.4302,20.4233,2.0000])