检查numpy数组是否全为非负的最佳方法

myss37ts  于 2023-03-02  发布在  其他
关注(0)|答案(2)|浏览(170)

这是可行的,但不是算法上的最优,因为我不需要在函数解析数组时存储min值:

def is_non_negative(m):
    return np.min(m) >= 0

编辑:根据数据的不同,一个最优函数确实可以保存很多,因为它会在第一次遇到负值时终止。如果只期望一个负值,平均时间将减少两倍。然而,在numpy库之外构建最优算法将花费巨大的成本(Python代码与C++代码)。

ldioqlga

ldioqlga1#

一个纯粹的Numpy解决方案是使用基于块的策略:

def is_non_negative(m):
    chunkSize = max(min(65536, m.size/8), 4096) # Auto-tunning
    for i in range(0, m.size, chunkSize):
        if np.min(m[i:i+chunkSize]) < 0:
            return False
    return True

这种解决方案只有在数组很大的情况下才是有效的,并且块足够大以使Numpy调用开销很小,并且块足够小以将全局数组拆分成许多部分(以便从早期的剪切中获益)。块大小需要相当大,以便平衡小数组上相对较大的np.min开销。
下面是一个Numba解决方案:

import numba as nb

# Eagerly compiled funciton for some mainstream data-types.
@nb.njit(['(float32[::1],)', '(float64[::1],)', '(int_[::1],)'])
def is_non_negative_nb(m):
    for e in m:
        if e < 0:
            return False
    return True

事实证明,这比在我的机器上使用np.min要快,尽管LLVM-Lite(Numba的JIT)没有很好地自动矢量化代码(即不使用SIMD指令)。
对于更快的代码,您需要使用C/C++代码和基于块的SIMD友好代码,如果编译器不能生成高效代码(不幸的是,这种情况相当常见),则可能使用SIMD内部函数。

ghhaqwfi

ghhaqwfi2#

一种可能的解决方案是使用C中的函数:

#include <stdio.h>
#include <stdlib.h>

int is_negative(double* data, int num_elems) {
    for (int i = 0; i < num_elems; i++) {
        if (data[i] < 0) {
            return 1;
        }
    }
    return 0;
}

编译:

gcc -c -fPIC is_negative.c -o is_negative.o

并链接:

gcc -shared is_negative.o -o libis_negative.so

然后,在Python中:

import numpy as np
import ctypes

lib = ctypes.cdll.LoadLibrary('/tmp/libis_negative.so')

a = np.array([1.0, 2.0, -3.0, 4.0])
num_elems = a.size

lib.is_negative.restype = ctypes.c_int
lib.is_negative.argtypes = [
    np.ctypeslib.ndpointer(dtype=np.float64),
    ctypes.c_int,
]

result = lib.is_negative(a, num_elems)

if result:
    print("It has negative elements")
else:
    print("It does not have negative elements")

相关问题