如何过滤numpy数组的行

xt0899hw  于 2023-04-06  发布在  其他
关注(0)|答案(2)|浏览(153)

我希望对numpy数组的每一行应用一个函数。如果此函数的计算结果为True,我将保留该行,否则我将丢弃它。例如,我的函数可能是:

def f(row):
    if sum(row)>10: return True
    else: return False

我想知道是否有类似的东西:

np.apply_over_axes()

它对numpy数组的每一行应用一个函数并返回结果。我希望是这样的:

np.filter_over_axes()

它会将一个函数应用于numpy数组的每一行,并且只返回函数返回True的行。有类似这样的东西吗?或者我应该使用for循环?

gev0vcfq

gev0vcfq1#

理想情况下,你应该能够实现一个向量化版本的函数,并使用它来做布尔索引。对于绝大多数问题来说,这是正确的解决方案。Numpy提供了相当多的函数,可以在各种轴上操作,以及所有的基本操作和比较,所以大多数有用的条件应该是可向量化的。

import numpy as np

x = np.random.randn(20, 3)
x_new = x[np.sum(x, axis=1) > .5]

如果你绝对确定你不能做到以上这些,我建议你使用一个列表解析(或者np.apply_along_axis)来创建一个布尔值数组来索引。

def myfunc(row):
    return sum(row) > .5

bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]

这将以一种相对干净的方式完成工作,但会比矢量化版本慢得多。例如:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop
wko9yo5t

wko9yo5t2#

正如@Roger Fan提到的,应用一个函数行方式 * 应该 * 真的以向量化的方式在整个数组上完成。过滤的规范方法是构造一个布尔掩码并将其应用于数组。也就是说,如果碰巧函数太复杂以至于无法向量化,将数组转换为Python列表(特别是如果它使用sum()等Python函数)并对其应用该函数会更好/更快。

msk = arr.sum(axis=1)>10                # best way to create a boolean mask

msk = [f(row) for row in arr.tolist()]  # second best way
#                            ^^^^^^^^   <---- convert to list

filtered_arr = arr[msk]                 # filtered via boolean indexing
工作示例和性能测试

从下面的timeit测试中可以看出,循环遍历列表(arr.tolist())比循环遍历numpy数组(arr)要快得多,部分原因是在函数f()中调用的是Python的sum()而不是np.sum()。也就是说,向量化方法比这两种方法都快得多。

def f(row):
    if sum(row)>10: return True
    else: return False
    
arr = np.random.rand(10000, 200)

%timeit arr[[f(row) for row in arr]]
# 260 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit arr[[f(row) for row in arr.tolist()]]
# 114 ms ± 4.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit arr[arr.sum(axis=1)>10]
# 10.8 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

相关问题