对数组使用numpy.random.normal

1bqhqjot  于 12个月前  发布在  其他
关注(0)|答案(3)|浏览(98)

假设我有以下两个数组,其中包含均值和标准差:

mu = np.array([2000, 3000, 5000, 1000])
sigma = np.array([250, 152, 397, 180])

字符串
然后:

a = np.random.normal(mu, sigma)

In [1]: a
Out[1]: array([1715.6903716 , 3028.54168667, 4731.34048645, 933.18903575])


但是,如果我要求对mu,sigma的每个元素进行100次抽奖:

a = np.random.normal(mu, sigma, 100)

a = np.random.normal(mu, sigma, 100)
Traceback (most recent call last):

File "<ipython-input-417-4aadd7d15875>", line 1, in <module>
a = np.random.normal(mu, sigma, 100)

File "mtrand.pyx", line 1652, in mtrand.RandomState.normal

File "mtrand.pyx", line 265, in mtrand.cont2_array

ValueError: shape mismatch: objects cannot be broadcast to a single shape


我也尝试过使用元组来表示大小:

s = (100, 100, 100, 100)
a = np.random.normal(mu, sigma, s)


我错过了什么?

35g0bw71

35g0bw711#

我不相信当你传递一个mean和std的值的列表/向量时,你可以控制size参数。相反,你可以遍历每一对,然后连接:

np.concatenate(
   [np.random.normal(m, s, 100) for m, s in zip(mu, sigma)]
)

字符串
这给了你一个(400, )数组。如果你想要一个(4, 100)数组,调用np.array而不是np.concatenate

xsuvu9jc

xsuvu9jc2#

如果你只想做一次调用,正态分布很容易在事后进行移位和重新缩放(我正在从你的例子中制作一个10000长的musigma向量):

mu = np.random.choice([2000., 3000., 5000., 1000.], 10000)               
sigma = np.random.choice([250., 152., 397., 180.], 10000)

a = np.random.normal(size=(10000, 100)) * sigma[:,None] + mu[:,None]

字符串
这很好用。你可以决定速度是否是一个问题。在我的系统上,以下只是慢了50%:

a = np.array([np.random.normal(m, s, 100) for m,s in zip(mu, sigma)])

vkc1a9a2

vkc1a9a23#

这是一个老问题,但我最近遇到了同样的问题,目前文档还不清楚,所以我的回答可能对其他人有用。
问题是,如果你想从具有n_param不同参数的(不相关的)正态分布中绘制n_sample样本,函数的size参数需要是元组(n_sample, n_param)。回到你的例子:

mu = np.array([2000, 3000, 5000, 1000])
sigma = np.array([250, 152, 397, 180])

n_sample = 10
n_param = len(mu)

np.random.normal(mu, sigma, (n_sample, n_param))

字符串
它返回

array([[2048.27840802, 2997.96810385, 4388.76381537,  834.58578664],
       [2284.62302217, 3057.37011582, 5141.42601472,  757.21437687],
       [1933.16814182, 3060.13736788, 5431.56812414,  949.80295487],
       [2444.69699622, 3049.32584965, 4850.82175943,  772.26041345],
       [2129.87928253, 2976.20614441, 5140.33783836, 1017.96741881],
       [1906.47137372, 2829.44037933, 4894.20964032, 1245.29240452],
       [2031.94886175, 2693.19106648, 5385.33674047,  849.72485587],
       [2034.22639971, 3017.86916011, 5050.08920701, 1198.48286148],
       [2278.8297283 , 3036.31308636, 5043.93694099,  988.87438521],
       [1760.04486593, 2875.0750094 , 4615.1775128 ,  946.76458665]])

相关问题