你能解释一下这段代码的输出吗?以及如何在numpy中使用where函数?

lb3vh1jj  于 2022-12-18  发布在  其他
关注(0)|答案(2)|浏览(145)
a = np.array([[1,2],[3,4]])
np.where(a<4)

答:

array([0,0,1]), array([0,1,0])

请解释输出:
答:

array([0,0,1]),array([0,1,0])

5uzkadbs

5uzkadbs1#

numpy.where为您提供真实值的索引。
我希望这个分解能帮助你理解其中的逻辑:

a = np. array([[1,2],[3,4]])
#         0  1
# array([[1, 2],   # 0
#        [3, 4]])  # 1

a<4
#            0      1
# array([[ True,  True],    0
#        [ True, False]])   1

# flat version
# row:          0     0       1      1
# col:          0     1       0      1
# # array([[ True,  True], [ True, False]])

# keep only the True
# row:  [0, 0, 1]
# col:  [0, 1, 0]

np.where(a<4)
# (array([0, 0, 1]), array([0, 1, 0]))
nle07wnf

nle07wnf2#

np.where返回一个包含(在本例中)2个元素的元组。
第一元素是满足条件的元素的行索引。
第二元素是满足条件的元素的列索引。
要进行检查,请保存结果,例如:

ind = np.where(a < 4)

现在,当您运行a[ind]时,您将得到一个由满足以下条件的元素填充的一维数组,即:

array([1, 2, 3])

如果源数组具有更多维度,则生成的元组将具有更多组件。

相关问题