Scipy多变量正态分布:如何抽取确定性样本?

xn1cxnb4  于 2022-11-10  发布在  其他
关注(0)|答案(1)|浏览(190)

我使用Scipy.stats.multivariate_normal从多元正态分布中抽取样本。如下所示:

from scipy.stats import multivariate_normal

# Assume we have means and covs

mn = multivariate_normal(mean = means, cov = covs)

# Generate some samples

samples = mn.rvs()

每次运行的样本都不一样。我如何总是得到相同的样本?我期待的是类似这样的结果:

mn = multivariate_normal(mean = means, cov = covs, seed = aNumber)

samples = mn.rsv(seed = aNumber)
osh3o9ms

osh3o9ms1#

有两种方法:

  1. rvs()方法接受一个random_state参数,它的值可以是一个整数种子,也可以是numpy.random.Generatornumpy.random.RandomState的一个示例,在这个例子中,我使用了一个整数种子:
In [46]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])

 In [47]: mn.rvs(size=5, random_state=12345)
 Out[47]: 
 array([[-0.51943872,  1.07094986, -1.0235383 ],
        [ 1.39340583,  4.39561899, -2.77865152],
        [ 0.76902257,  0.63000355,  0.46453938],
        [-1.29622111,  2.25214387,  6.23217368],
        [ 1.35291684,  0.51186476,  1.37495817]])

 In [48]: mn.rvs(size=5, random_state=12345)
 Out[48]: 
 array([[-0.51943872,  1.07094986, -1.0235383 ],
        [ 1.39340583,  4.39561899, -2.77865152],
        [ 0.76902257,  0.63000355,  0.46453938],
        [-1.29622111,  2.25214387,  6.23217368],
        [ 1.35291684,  0.51186476,  1.37495817]])

此版本使用numpy.random.Generator的示例:

In [34]: rng = np.random.default_rng(438753948759384)

In [35]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])

In [36]: mn.rvs(size=5, random_state=rng)
Out[36]: 
array([[ 0.30626179,  0.60742839,  2.86919105],
       [ 1.61859885,  2.63409111,  1.19018398],
       [ 0.35469027,  0.85685011,  6.76892829],
       [-0.88659459, -0.59922575, -5.43926698],
       [ 0.94777687, -5.80057427, -2.16887719]])

1.你可以为numpy的全局随机数生成器设置seed,这是multivariate_normal.rvs()在没有给定random_state的情况下使用的生成器:

In [54]: mn = multivariate_normal(mean=[0,0,0], cov=[1, 5, 25])

 In [55]: np.random.seed(123)

 In [56]: mn.rvs(size=5)
 Out[56]: 
 array([[  0.2829785 ,   2.23013222,  -5.42815302],
        [  1.65143654,  -1.2937895 ,  -7.53147357],
        [  1.26593626,  -0.95907779, -12.13339622],
        [ -0.09470897,  -1.51803558,  -4.33370201],
        [ -0.44398196,  -1.4286283 ,   7.45694813]])

 In [57]: np.random.seed(123)

 In [58]: mn.rvs(size=5)
 Out[58]: 
 array([[  0.2829785 ,   2.23013222,  -5.42815302],
        [  1.65143654,  -1.2937895 ,  -7.53147357],
        [  1.26593626,  -0.95907779, -12.13339622],
        [ -0.09470897,  -1.51803558,  -4.33370201],
        [ -0.44398196,  -1.4286283 ,   7.45694813]])

相关问题