numpy 计算坐标下的点

ttp71kqs  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(107)

我有以下问题,我目前使用的解决方案太慢了。

**示例:**形状为(B,2)的numpy数组b,形状为(X)的排序numpy数组x,形状为(Y)的排序numpy数组y

注意x = np.unique(b[:,0])y = np.unique(b[:,1]),如果这对问题有区别的话。

**任务:**构建(X,Y)-array H,使得H[i,j]b中第一个条目小于x[i]且第二个条目小于y[j]的行数。

下面的示例代码解决了这个问题:

import numpy as np
b = np.random.random((2000,2))
x = np.unique(b[:,0])
y = np.unique(b[:,1])
H = np.count_nonzero(
    np.logical_and(
        b[:,0,None,None] <= x[None,:,None], 
        b[:,1,None,None] <= y[None,None,:]
    ), 
    axis=0
)

但是如果b以及因此xy具有几千个条目,则这变得相当慢。
我如何才能更有效地做到这一点?

wgeznvg7

wgeznvg71#

算法

您当前的实现很好地进行了向量化,但与可以完成的工作相比,它计算的工作量太大了。主要的问题是这个实现的算法复杂度是**O(n**3)
这个问题可以用
O(n**2)算法解决。这是最佳复杂度**,因为H至少需要填充。关键是对by进行排序,以便在计算H行时有效地计算条目的数量。排序后的数组需要通过x进行过滤,这样就不用关心热循环中的x值。注意,这种方法假设xy是排序的,因为np.unique对值进行排序。

实现

为了提高实现的性能,可以使用Numba多线程。注意内存布局对算法缓存友好也很重要。
下面是最终的实现:

import numba as nb

@nb.njit('(float64[::1], float64[::1], float64[:,::1])', parallel=True)
def compute(x, y, b):
    H = np.empty((x.size, y.size), dtype=np.int32)

    # Improve the memory layout
    b_x = b[:,0]
    b_y = b[:,1]

    # Sort both b_x and b_y pairs in a way b_y is sorted
    sorted_idx = np.argsort(b_y)
    sorted_b_x = b_x[sorted_idx]
    sorted_b_y = b_y[sorted_idx]

    for i in nb.prange(x.size):
        # Filter by x
        sorted_filtered_b_y = sorted_b_y[sorted_b_x <= x[i]]

        size = sorted_filtered_b_y.size
        count = 0
        idx = 0

        for j in range(y.size):
            cur_y = y[j]
            while idx < size and sorted_filtered_b_y[idx] <= cur_y:
                count += 1
                idx += 1
            H[i, j] = count

    return H

x = np.unique(b[:,0])
y = np.unique(b[:,1])
compute(x, y, b)

性能结果

以下是在我的i5- 9600 KF机器上使用所提供的输入的性能结果:

Naive implementation:    9874.74 ms
This implementation:        3.85 ms

因此,这种实现方式的速度是2565倍

相关问题