如何应用函数元素与输入从多个numpy掩码数组来创建一个新的掩码数组?

amrnrhlw  于 2022-11-29  发布在  其他
关注(0)|答案(1)|浏览(158)

我有一个函数,它接受4个单值输入,返回一个单浮点输出,例如:

from scipy.stats import multivariate_normal

grid_step = 0.25 #in units of sigma
grid_x, grid_y = np.mgrid[-2:2+grid_step:grid_step, -2:2+grid_step:grid_step]
pos = np.dstack((grid_x, grid_y))
rv = multivariate_normal([0.0, 0.0], [[1.0, 0], [0, 1.0]])
grid_pdf = rv.pdf(pos)*grid_step**2
norm_pdf = np.sum(rv.pdf(pos))*grid_step**2

def cal_prob(x, x_err, y, y_err):
    x_grid = grid_x*x_err + x
    y_grid = grid_y*y_err + y
    PSB_grid = ((x_grid>3) & (y_grid<10) & (y_grid < 10**(0.23*x_grid-0.46)))
    PSB_prob = np.sum(PSB_grid*grid_pdf)/norm_pdf
    return PSB_prob

此函数的作用是在给定x和y的不确定性的情况下,估计某些x-y测量值在x-y空间中的某个定义限制内的概率。它假设不确定性是高斯的且不相关。然后,使用预先设置的grid_pdf,它检查哪些网格点(按x_err/y_err缩放并按x/y移位)在定义的限制范围内,将True/False网格乘以grid_pdf,由norm_pdf规范化。概率由规范化数组的总和给出。
我想把这个函数应用到元素层面,把这4个输入存储在4个不同的掩码数组中,这些掩码数组的形状相同,但可能有不同的掩码,然后用函数的输出创建一个新的相同形状的数组。有没有不使用for循环的方法?
谢谢你!
我目前的解决方案是:

mask1 = np.array([[False, True, False],[True, True, True],[True, False, False]])
mask2 = np.array([[True, True, True],[True, True, False],[False, False, True]])
# the only overlaps should be [0,1], [1,0] and [1,1]

x = np.ma.array(np.random.randn(*mask1.shape), mask=~mask1)
x_err = np.ma.array(np.abs(np.random.randn(*mask1.shape))*0.1, mask=~mask1)

y = np.ma.array(np.random.randn(*mask2.shape), mask=~mask2)
y_err = np.ma.array(np.abs(np.random.randn(*mask2.shape))*0.1, mask=~mask2)

# a combined mask to iterate through
all_mask = x+x_err+y+y_err

prob = np.zeros(mask1.shape)
prob = np.ma.masked_where(np.ma.getmask(all_mask), prob)

for i,xi in np.ma.ndenumerate(all_mask):
    prob[i] = cal_prob(xi, x_err[i], y[i], y_err[i])
kq0g1dla

kq0g1dla1#

使用掩码数组输入的np.vectorize测试:

In [180]: def foo(x):
     ...:     print(x)
     ...:     return 2*x
     ...:     

In [181]: np.vectorize(foo)(np.ma.masked_array([1,2,3],[True,False,True]))
1
1
2
3
Out[181]: 
masked_array(data=[--, 4, --],
             mask=[ True, False,  True],
       fill_value=999999)

In [182]: _.data
Out[182]: array([2, 4, 6])

相关问题