如何使用掩码来限制两个numpy数组之间的广播操作?

de90aj5v  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(84)

我有一个这样的数组:

data = np.array([
    [[10, 10, 10],
     [10, 10, 10],
     [10, 10, 10]],

    [[20, 20, 20],
     [20, 20, 20],
     [20, 20, 20]],

    [[30, 30, 30],
     [30, 30, 30],
     [30, 30, 30]],
], dtype=np.float64)

另一个用来除数值,就像这样:

divide_by = np.array([
    [[10, 10, 1]],
    [[1, 10, 10]],
    [[1, 1, 1]],
], dtype=np.float64)

我想将data数组的每一行(轴0)除以divide_by数组中的值(有点像邮票),但仅限于给定掩码(作为data的形状)被设置为True的位置。
所以第一部分我可以通过以下方式实现:

divide_by = divide_by.reshape(divide_by.shape[0], divide_by.shape[2])

data /= divide_by

print(data)

其产生:

[[[ 1.  1. 10.]
  [10.  1.  1.]
  [10. 10. 10.]]

 [[ 2.  2. 20.]
  [20.  2.  2.]
  [20. 20. 20.]]

 [[ 3.  3. 30.]
  [30.  3.  3.]
  [30. 30. 30.]]]

请注意,data数组的每一行都被divide_by中的内容所分割,就像在上面盖了一个图章一样。太好了
现在我想做同样的事情,但只在掩码设置为true的地方应用除法:

mask = np.array([
    [[False, True, False],
     [False, False, False],
     [True, False, False]],

    [[True, True, True],
     [False, False, True],
     [False, False, False]],

    [[True, False, False],
     [False, False, False],
     [False, False, False]],
])

因此,期望输出为:

[[[10.  1. 10.]
  [10. 10.  1.]
  [10. 10. 10.]]

 [[ 2.  2. 20.]
  [20. 20.  2.]
  [20. 20. 20.]]

 [[ 3. 30. 30.]
  [30. 30. 30.]
  [30. 30. 30.]]]

掩码定义了一个要除的地方的子集,
但如果我做了:

data[mask] /= divide_by

而不是

data /= divide_by

我得到:

ValueError: operands could not be broadcast together with shapes (7,) (3,3) (7,)

在这种特殊情况下,我如何使用这个面具?

icomxhvb

icomxhvb1#

可以使用np.where(mask, data / divide_by[None, :, 0], data)

相关问题