使用numpy进行按位数组操作

p5cysglq  于 2023-08-05  发布在  其他
关注(0)|答案(3)|浏览(87)

我想翻译下面的C++表达式

(i0?8:0) | (i1?4:0) | (i2?2:0) | (i3?1:0)

字符串
python,其中i0i1i2i3是bool的numpy.ndarray(通过i0 = x < y等表达式获得)。实现这一目标的最有效方法是什么?我天真的做法是

np.where(i0,8,0) | np.where(i1,4,0) | np.where(i2,2,0) |  np.where(i3,1,0)


但是有更快的方法吗?

gg0vcinb

gg0vcinb1#

您可以:

  • 使用view so将布尔值转换为1字节整数值,不需要任何拷贝;
  • 使用1字节整数计算操作(对于大型数组更快,因为它们在RAM和CPU缓存中占用的空间更少);
  • 用移位替换np.where,通常计算起来要快一些(多亏了SIMD指令);
  • 由于许多Numpy函数的第三个参数(out),可以执行就地计算,以避免创建无用的临时数组

下面是生成的代码:

def compute_faster(i0, i1, i2, i3):
    tmp1 = np.left_shift(i0.view(np.uint8), np.uint8(3))
    tmp2 = np.left_shift(i1.view(np.uint8), np.uint8(2))
    np.bitwise_or(tmp1, tmp2, tmp1)
    np.left_shift(i2.view(np.uint8), np.uint8(1), tmp2)
    np.bitwise_or(tmp1, tmp2, tmp1)
    np.bitwise_or(tmp1, i3.view(np.uint8), tmp1)
    return tmp1

字符串
我假设输入值是布尔类型的数组,因为x < y返回这样的数组。注意,输出是一个uint 8类型的数组(所以你应该关心溢出和有符号的值,或者你可以只转换它或int 32类型的数组以使其安全,代价是计算速度较慢)。
这当然是只使用Numpy的最佳解决方案。为了获得更快的性能,您当然需要使用本机编译代码。NumbaCython是一个很好的加速方法。以下是使用Numba的解决方案:

import numba as nb

# Assume the input arrays are contiguous. 
# Please use bool_[:] otherwise
@nb.njit('(bool_[::1], bool_[::1], bool_[::1], bool_[::1])')
def compute_fastest(i0, i1, i2, i3):
    out = np.empty(i0.size, dtype=np.uint8)
    ui0 = i0.view(np.uint8)
    ui1 = i1.view(np.uint8)
    ui2 = i2.view(np.uint8)
    ui3 = i3.view(np.uint8)
    for i in range(i0.size):
        # `+` is faster than `|` since `|` is not a bitwise operator in Python and Numba keep its semantic.
        # This enable the compiler to generate fast SIMD instructions
        out[i] = (ui0[i] << np.uint8(3)) + (ui1[i] << np.uint8(2)) + (ui2[i] << np.uint8(1)) + ui3[i]
    return out

性能结果

以下是在我的i5- 9600 KF处理器上使用Numpy 1.24.3的大小为4096的数组的性能结果:

reference code:    27.2 µs
compute_faster:    17.7 µs
compute_fastest:    1.5 µs


Numba代码当然非常接近于非常短的数组的最佳值。对于非常短的代码,Cython应该更好,因为它可能会降低开销。对于非常大的数组,可以使用多线程(使用parallel=True标志和nb.prange而不是range)。

tp5buhyn

tp5buhyn2#

假设i0/i1/i2/i3的布尔数组可以用途:

out = 8*i0 + 4*i1 + 2*i2 + i3

字符串

2eafrhcq

2eafrhcq3#

您可以使用按位运算符:

out = (i0<<3) | (i1<<2) | (i2<<1) | (i3<<0)

字符串

相关问题