numpy 计算两个列表的所有欠条的有效方法

kb5ga3dv  于 2022-11-10  发布在  其他
关注(0)|答案(2)|浏览(135)

我有一个函数可以计算两个矩形/边界框的欠条。

def intersection_over_union(boxA, boxB):
    # determine the (x, y)-coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # compute the area of intersection rectangle
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

    # compute the area of both the prediction and ground-truth
    # rectangles
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the interesection area
    iou = interArea / float(boxAArea + boxBArea - interArea)

    # return the intersection over union value
    return iou

现在我想用另一个列表的bbox计算一个列表的bbox的所有借条,也就是说,如果列表A包含4个bbox,列表B包含3个bbox,那么我想要一个包含所有可能的借条的4x3矩阵。
当然,我可以使用这样的两个循环来完成此操作

import numpy as np

n_i = len(bboxes_a)
n_j = len(bboxes_b)
iou_mat = np.empty((n_i, n_j))
for i in range(n_i):
    for j in range(n_j):
        iou_mat[i, j] = intersection_over_union(bboxes_a[i], bboxes_b[j])

但这种方法非常慢,特别是当列表变得非常大的时候。
我正在努力寻找一种更有效的方法。肯定有一种方法可以利用NumPy来消除循环,但我不明白。而且现在的复杂度是O(m*n)。有没有可能降低复杂性?

kxxlusnw

kxxlusnw1#

矢量化:

low = np.s_[...,:2]
high = np.s_[...,2:]

def iou(A,B):
    A,B = A.copy(),B.copy()
    A[high] += 1; B[high] += 1
    intrs = (np.maximum(0,np.minimum(A[high],B[high])
                        -np.maximum(A[low],B[low]))).prod(-1)
    return intrs / ((A[high]-A[low]).prod(-1)+(B[high]-B[low]).prod(-1)-intrs)

AB = iou(A[:,None],B[None])

复杂性:

由于您正在计算M x N值,因此不可能将复杂性降低到M x N以下,除非大多数值为零,并且矩阵的稀疏表示是可以接受的。
这可以通过对A和B的所有末端进行加法排序(分别对于x和y)来完成。这就是O(M+N)log(M+N))EDIT,因为坐标是整数线性复杂性在这里是可能的。编辑结束然后可用于预过滤A x B。过滤和计算非零的复杂度为O(M+N+非零数)。

pobjuy32

pobjuy322#

您可以在Python的itertools中使用product()来替换嵌套的for循环。我认为使用内置函数总是更好的。示例可以如下所示:

import numpy as np

l1 = np.random.randint(0, 10, (4, 4))
l2 = np.random.randint(0, 10, (3, 4))
print(f'l1:\n{l1}')
print(f'l2:\n{l2}')

from itertools import product

ious = np.array([intersection_over_union(box1, box2) for box1, box2 in product(l1, l2)]).reshape(len(l2), len(l1))
print(f'ious:\n{ious}')

此外,应将iou = interArea / float(boxAArea + boxBArea - interArea)改为iou = interArea / float(boxAArea + boxBArea - interArea + 1e-16),以避免divided by zero error.

相关问题