如何在numpy中过滤3维数组

zwghvu4y  于 2023-08-05  发布在  其他
关注(0)|答案(3)|浏览(73)

我知道这是numpy中最简单的任务之一,但我不明白。我有一个数组:

import numpy as np

u1 = np.array([[[-2.04853845, -1.47283101,  1.        ],
                [-2.05009646, -1.46780913,  1.        ],
                [-2.05165539, -1.46278426,  1.        ]],
               [[-4.94165412,  7.8524744 ,  1.        ],
                [ 0.94540647,   0.86456925,  1.       ],
                [-4.94916228,  7.87667525,  1.        ]]])

字符串
u1是矩阵的汇总。矩阵的真实的大小为:

u1.shape
(813, 1200, 3)


我想得到最后一个数组(最后一个维度)的第一个和第二个元素在0和1之间的所有值,所以我得到一个包含布尔值的数组,这样我就可以索引另一个具有相同形状的数组。
我敢打赌这是一个重复的问题,但我没有找到它。谢啦,谢啦

anhgbhbe

anhgbhbe1#

您正在寻找以下内容:

import numpy as np

u1 = np.array([[[-2.04853845, -1.47283101,  1.],
                [-2.05009646, -1.46780913,  1.],
                [-2.05165539, -1.46278426,  1.]],
               [[-4.94165412,  7.8524744,   1.],
                [ 0.94540647,  0.86456925,  1.],
                [-4.94916228,  7.87667525,  1.]]])

np.logical_and(u1[:,:,0:2]>0, u1[:,:,0:2]<1)
array([[[False, False],
        [False, False],
        [False, False]],

       [[False, False],
        [ True,  True],
        [False, False]]])

字符串

ppcbkaq5

ppcbkaq52#

提取元素并与您的值进行比较:

import numpy as np

u1 = np.array([[[-2.04853845, -1.47283101,  1.        ],
                [-2.05009646, -1.46780913,  1.        ],
                [-2.05165539, -1.46278426,  1.        ]],
               [[-4.94165412,  7.8524744 ,  1.        ],
                [ 0.94540647,   0.86456925,  1.       ],
                [-4.94916228,  7.87667525,  1.        ]]])

first = (u1[:, :, 0] > 0) * (u1[:, :, 0] < 1)
second = (u1[:, :, 1] > 0) * (u1[:, :, 1] < 1)

condition = first * second
# Condition is:
# [[False False False]
# [False  True False]]

字符串
然后提取您感兴趣的行:

values = u1[condition]

print(values)
# Prints:
# [[0.94540647 0.86456925 1.        ]]

l7wslrjt

l7wslrjt3#

定义以下函数,检查行的前两个元素:

def myfunc(row):
    return np.logical_and(row[0:2] >= 0, row[0:2] <= 1).all()

字符串
然后将其应用于最后一个维度沿着的每个一维切片(我们称之为 * 行 *):

np.apply_along_axis( myfunc, axis=-1, arr=u1 )


示例数据的结果为:

array([[False, False, False],
       [False,  True, False]])

相关问题