numpy: Can this recursive function be turned into an iterative function with similar performance?

Asked by lfapxunr about 12 months ago

I'm writing a function in Python, using numba, to label objects in a 2D or 3D array. This means that all orthogonally connected cells with the same value in the input array get a unique label from 1 to N in the output array, where N is the number of orthogonally connected groups. It's very similar to functions like scipy.ndimage.label and similar functions in libraries such as scikit-image, but those functions label all orthogonally connected groups of non-zero cells, so they merge connected groups that have different values, which I don't want. For example, given this input:

[0 0 7 7 0 0
 0 0 7 0 0 0
 0 0 0 0 0 7
 0 6 6 0 0 7
 0 0 4 4 0 0]

the scipy function returns

[0 0 1 1 0 0
 0 0 1 0 0 0
 0 0 0 0 0 3
 0 2 2 0 0 3
 0 0 2 2 0 0]

Note that the 6s and 4s were merged into label 2. I want them labelled as separate groups, e.g.:

[0 0 1 1 0 0
 0 0 1 0 0 0
 0 0 0 0 0 4
 0 2 2 0 0 4
 0 0 3 3 0 0]
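
Just to pin down the semantics I'm after: one way to express this with plain scipy is to call scipy.ndimage.label once per distinct non-zero value and offset the labels, roughly like the sketch below (the exact label numbers may come out in a different order than in my example above).

import numpy as np
from scipy import ndimage

def label_per_value(arr):
    # Label each distinct non-zero value separately with scipy, then offset
    # the labels so every orthogonally connected group gets a unique id.
    out = np.zeros(arr.shape, dtype=np.uint32)
    next_offset = 0
    for value in np.unique(arr):
        if value == 0:
            continue
        labeled, n = ndimage.label(arr == value)
        mask = labeled > 0
        out[mask] = labeled[mask] + next_offset
        next_offset += n
    return out, next_offset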

I asked this about a year ago and have been using the solution from the accepted answer, but now I'm optimizing my code's runtime and revisiting the problem.
For the data sizes I typically work with, the linked solution takes about 1m30s to run. I wrote the recursive algorithm below, which takes about 30 seconds as plain Python and 1-2 seconds with numba's JIT (side note: I hate the adjacent function; any tips to make it less messy while staying numba-compatible would be welcome):

import numba
import numpy as np

@numba.njit
def adjacent(idx, shape):
    coords = []
    if len(shape) > 2:
        if idx[0] < shape[0] - 1:
            coords.append((idx[0] + 1, idx[1], idx[2]))
        if idx[0] > 0:
            coords.append((idx[0] - 1, idx[1], idx[2]))
        if idx[1] < shape[1] - 1:
            coords.append((idx[0], idx[1] + 1, idx[2]))
        if idx[1] > 0:
            coords.append((idx[0], idx[1] - 1, idx[2]))
        if idx[2] < shape[2] - 1:
            coords.append((idx[0], idx[1], idx[2] + 1))
        if idx[2] > 0:
            coords.append((idx[0], idx[1], idx[2] - 1))
    else:
        if idx[0] < shape[0] - 1:
            coords.append((idx[0] + 1, idx[1]))
        if idx[0] > 0:
            coords.append((idx[0] - 1, idx[1]))
        if idx[1] < shape[1] - 1:
            coords.append((idx[0], idx[1] + 1))
        if idx[1] > 0:
            coords.append((idx[0], idx[1] - 1))
    return coords

@numba.njit
def apply_label(labels, decoded_image, current_label, idx):
    labels[idx] = current_label
    for aidx in adjacent(idx, labels.shape):
        if decoded_image[aidx] == decoded_image[idx] and labels[aidx] == 0:
            apply_label(labels, decoded_image, current_label, aidx)

@numba.njit
def label_image(decoded_image):
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image >= 0)):
        if labels[idx] == 0:
            current_label += 1
            apply_label(labels, decoded_image, current_label, idx)
    return labels, current_label

This works for some data but crashes on other data. I found that the problem is that the recursion limit is hit when there are very large objects to label. I tried rewriting label_image without recursion, but with numba it now takes 10 seconds. That's still a huge improvement over where I started, but it seems like it should be possible to match the performance of the recursive version. Here is my iterative version:

@numba.njit
def label_image(decoded_image):
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image >= 0)):
        if labels[idx] == 0:
            current_label += 1
            idxs = [idx]
            while idxs:
                cidx = idxs.pop()
                if labels[cidx] == 0:
                    labels[cidx] = current_label
                    for aidx in adjacent(cidx, labels.shape):
                        if labels[aidx] == 0 and decoded_image[aidx] == decoded_image[idx]:
                            idxs.append(aidx)
    return labels, current_label

Is there any way to improve this?

yk9xbfzb1#

Can this recursive function be turned into an iterative function with similar performance?
Turning it into an iterative function is straightforward, since it's just a simple depth-first search (a breadth-first search with a queue instead of a stack works just as well; a BFS sketch is shown after the explicit 2D/3D versions below). Just use a stack to keep track of the nodes still to visit. Here's a generic solution that works for any number of dimensions:

def label_image(decoded_image):
    shape = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                top = stack.pop()
                labels[top] = current_label
                for i in range(0, len(shape)):
                    if top[i] > 0:
                        neighbor = list(top)
                        neighbor[i] -= 1
                        neighbor = tuple(neighbor)
                        if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
                            stack.append(neighbor)
                    if top[i] < shape[i] - 1:
                        neighbor = list(top)
                        neighbor[i] += 1
                        neighbor = tuple(neighbor)
                        if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
                            stack.append(neighbor)
    return labels

Adding or subtracting one from the i-th component of a tuple is awkward (I go through a temporary list here), and numba doesn't accept it (a typing error). A simple solution is to write explicit 2D and 3D versions, which also happens to help performance a lot:

@numba.njit
def label_image_2d(decoded_image):
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                x, y = stack.pop()
                if decoded_image[x, y] != decoded_image[idx] or labels[x, y] != 0:
                    continue # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: stack.append((x-1, y))
                if x+1 < w: stack.append((x+1, y))
                if y > 0: stack.append((x, y-1))
                if y+1 < h: stack.append((x, y+1))
    return labels

@numba.njit
def label_image_3d(decoded_image):
    w, h, l = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                x, y, z = stack.pop()
                if decoded_image[x, y, z] != decoded_image[idx] or labels[x, y, z] != 0:
                    continue # already visited or not part of this group
                labels[x, y, z] = current_label
                if x > 0: stack.append((x-1, y, z))
                if x+1 < w: stack.append((x+1, y, z))
                if y > 0: stack.append((x, y-1, z))
                if y+1 < h: stack.append((x, y+1, z))
                if z > 0: stack.append((x, y, z-1))
                if z+1 < l: stack.append((x, y, z+1))
    return labels

def label_image(decoded_image):
    dim = len(decoded_image.shape)
    if dim == 2:
        return label_image_2d(decoded_image)
    assert dim == 3
    return label_image_3d(decoded_image)
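
As mentioned above, the breadth-first variant works just as well; the only difference from label_image_2d is that the frontier is popped from the front instead of the back. A minimal non-numba sketch for the 2D case (numba would need a different queue structure, since it doesn't support collections.deque):

from collections import deque

import numpy as np

def label_image_2d_bfs(decoded_image):
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            queue = deque([idx])
            while queue:
                x, y = queue.popleft()  # FIFO frontier instead of a LIFO stack
                if decoded_image[x, y] != decoded_image[idx] or labels[x, y] != 0:
                    continue  # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: queue.append((x-1, y))
                if x+1 < w: queue.append((x+1, y))
                if y > 0: queue.append((x, y-1))
                if y+1 < h: queue.append((x, y+1))
    return labels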

Also note that the iterative solution is not bound by the stack limit: np.full((100, 100, 100), 1) works fine with the iterative solution, but fails with the recursive one (it segfaults when using numba).
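
For instance, this runs fine with the iterative label_image above, while the recursive version crashes on the same input:

labels = label_image(np.full((100, 100, 100), 1))  # one 100³ block, a single huge component
assert labels.max() == 1  # labelled without hitting any recursion/stack limit
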
Doing a very basic benchmark,

for i in range(1, 10000):
    label_image(np.full((20,20,20), i))

(many iterations to minimize the impact of the JIT; one could also do some warm-up runs before starting to measure the time, or something along those lines)
the iterative solution seems to be several times faster (roughly 5x on my machine, see below). You may be able to optimize the recursive solution to a comparable speed, e.g. by avoiding the temporary coords list or by changing the np.where condition to > 0.
I don't know how well numba optimizes the zipped np.where. To optimize further, you could consider (and benchmark) explicit nested for x in range(0, w): for y in range(0, h): loops.
To stay competitive with the merge strategy proposed by Nick, I optimized this further, going for some low-hanging fruit:

  • Replace the zipped np.where with explicit nested loops and a continue.
  • Store decoded_image[idx] in a local variable (ideally it shouldn't matter, but it doesn't hurt).
  • Reuse the stack. This prevents unnecessary (re)allocations and GC pressure. You could also consider giving the stack an initial capacity of w*h or w*h*l respectively (one way to do that with a preallocated array is sketched after the 3D version below).

@numba.njit
def label_image_2d(decoded_image):
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    stack = []
    for sx in range(0, w):
        for sy in range(0, h):
            start = (sx, sy)
            image_label = decoded_image[start]
            if image_label <= 0 or labels[start] != 0:
                continue
            current_label += 1
            stack.append(start)
            while stack:
                x, y = stack.pop()
                if decoded_image[x, y] != image_label or labels[x, y] != 0:
                    continue # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: stack.append((x-1, y))
                if x+1 < w: stack.append((x+1, y))
                if y > 0: stack.append((x, y-1))
                if y+1 < h: stack.append((x, y+1))
    return labels

@numba.njit
def label_image_3d(decoded_image):
    w, h, l = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    stack = []
    for sx in range(0, w):
        for sy in range(0, h):
            for sz in range(0, l):
                start = (sx, sy, sz)
                image_label = decoded_image[start]
                if image_label <= 0 or labels[start] != 0:
                    continue
                current_label += 1
                stack.append(start)
                while stack:
                    x, y, z = stack.pop()
                    if decoded_image[x, y, z] != image_label or labels[x, y, z] != 0:
                        continue # already visited or not part of this group
                    labels[x, y, z] = current_label
                    if x > 0: stack.append((x-1, y, z))
                    if x+1 < w: stack.append((x+1, y, z))
                    if y > 0: stack.append((x, y-1, z))
                    if y+1 < h: stack.append((x, y+1, z))
                    if z > 0: stack.append((x, y, z-1))
                    if z+1 < l: stack.append((x, y, z+1))
    return labels
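
One way to act on that last point, since numba lists don't (as far as I know) expose a way to reserve capacity, is to replace the list with a preallocated array used as a manual stack. If cells are labelled when they are pushed rather than when they are popped, every cell enters the stack at most once, so w*h rows can never overflow. A rough, unbenchmarked sketch of that idea for the 2D case (the 3D version would be analogous):

@numba.njit
def label_image_2d_prealloc(decoded_image):
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    # Manual stack: each row holds one (x, y) pair, `top` is the stack size.
    # Cells are labelled at push time, so each cell is pushed at most once.
    stack = np.empty((w * h, 2), dtype=np.intp)
    for sx in range(w):
        for sy in range(h):
            image_label = decoded_image[sx, sy]
            if image_label <= 0 or labels[sx, sy] != 0:
                continue
            current_label += 1
            labels[sx, sy] = current_label
            stack[0, 0] = sx
            stack[0, 1] = sy
            top = 1
            while top > 0:
                top -= 1
                x = stack[top, 0]
                y = stack[top, 1]
                if x > 0 and labels[x - 1, y] == 0 and decoded_image[x - 1, y] == image_label:
                    labels[x - 1, y] = current_label
                    stack[top, 0] = x - 1
                    stack[top, 1] = y
                    top += 1
                if x + 1 < w and labels[x + 1, y] == 0 and decoded_image[x + 1, y] == image_label:
                    labels[x + 1, y] = current_label
                    stack[top, 0] = x + 1
                    stack[top, 1] = y
                    top += 1
                if y > 0 and labels[x, y - 1] == 0 and decoded_image[x, y - 1] == image_label:
                    labels[x, y - 1] = current_label
                    stack[top, 0] = x
                    stack[top, 1] = y - 1
                    top += 1
                if y + 1 < h and labels[x, y + 1] == 0 and decoded_image[x, y + 1] == image_label:
                    labels[x, y + 1] = current_label
                    stack[top, 0] = x
                    stack[top, 1] = y + 1
                    top += 1
    return labels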

Then I threw together a benchmark comparing the four approaches (original recursive, old iterative, new iterative, merge-based), putting them in four separate modules:

import numpy as np
import timeit

import rec
import iter_old
import iter_new
import merge

shape = (100, 100, 100)
n = 20
for module in [rec, iter_old, iter_new, merge]:
    print(module)

    label_image = module.label_image
    # Trigger compilation of 2d & 3d functions
    label_image(np.zeros((1, 1)))
    label_image(np.zeros((1, 1, 1)))

    i = 0
    def test_full():
        global i
        i += 1
        label_image(np.full(shape, i))
    print("single group:", timeit.timeit(test_full, number=n))
    print("random (few groups):", timeit.timeit(
        lambda: label_image(np.random.randint(low = 1, high = 10, size = shape)),
        number=n))
    print("random (many groups):", timeit.timeit(
        lambda: label_image(np.random.randint(low = 1, high = 400, size = shape)),
        number=n))
    print("only groups:", timeit.timeit(
        lambda: label_image(np.arange(np.prod(shape)).reshape(shape)),
        number=n))

This outputs something like

<module 'rec' from '...'>
single group: 32.39212468900041
random (few groups): 14.648884047001047
random (many groups): 13.304533919001187
only groups: 13.513677138000276
<module 'iter_old' from '...'>
single group: 10.287227957000141
random (few groups): 17.37535468200076
random (many groups): 14.506630064999626
only groups: 13.132202609998785
<module 'iter_new' from '...'>
single group: 7.388022166000155
random (few groups): 11.585243002000425
random (many groups): 9.560101995000878
only groups: 8.693653742000606
<module 'merge' from '...'>
single group: 14.657021331999204
random (few groups): 14.146574055999736
random (many groups): 13.412314713001251
only groups: 12.642367746000673

To me, it looks like the improved iterative approach may be ahead. Note that the original, very basic benchmark seems to be close to a worst case for the recursive variant. Overall the differences aren't huge.
The tested arrays were quite small (20³). If I test with a larger array (100³) and a smaller n (20), I roughly get the following results (rec is omitted because it segfaults due to the stack limit):

<module 'iter_old' from '...'>
single group: 3.5357716739999887
random (few groups): 4.931695729999774
random (many groups): 3.4671142009992764
only groups: 3.3023930709987326
<module 'iter_new' from '...'>
single group: 2.45903080700009
random (few groups): 2.907660342001691
random (many groups): 2.309699692999857
only groups: 2.052835552000033
<module 'merge' from '...'>
single group: 3.7620838259990705
random (few groups): 3.3524249689999124
random (many groups): 3.126650959999097
only groups: 2.9456547739991947

The iterative approach still seems to be more efficient.

hgc7kmma2#

Here's my attempt. The idea is to apply a row-by-row connected-components algorithm: instead of trying to flood-fill to find all members of a component, you label cells as you scan each row and take note every time you hit a contradiction.
Then you take your list of contradictions and merge those labels into the same class.
This video explains the algorithm.
The downside is that it requires two passes over the input data. The advantage is that it has very good cache coherency; in other words, it accesses all the data in the array sequentially.
This is the same algorithm that scipy.ndimage.label() follows internally, but we can be faster because we don't need to call it once per class.
Here's the code I used.

import numba
import numpy as np

@numba.njit
def adjacent(idx, shape):
    coords = []
    if len(shape) > 2:
        if idx[0] > 0:
            coords.append((idx[0] - 1, idx[1], idx[2]))
        if idx[1] > 0:
            coords.append((idx[0], idx[1] - 1, idx[2]))
        if idx[2] > 0:
            coords.append((idx[0], idx[1], idx[2] - 1))
    else:
        if idx[0] > 0:
            coords.append((idx[0] - 1, idx[1]))
        if idx[1] > 0:
            coords.append((idx[0], idx[1] - 1))
    return coords

@numba.njit
def merge_classes(labels, mergetable):
    for idx in np.ndindex(labels.shape):
        class_num = labels[idx]
        if class_num < len(mergetable):
            merge_target = mergetable[class_num]
            if merge_target != -1:
                labels[idx] = merge_target

@numba.njit
def add_to_merge_table(mergetable, class1, class2):
    # identify smallest element
    lo_class = min(class1, class2)
    hi_class = max(class1, class2)
    # Does the merge table require expansion?
    while len(mergetable) <= hi_class:
        new_mergetable = np.zeros(len(mergetable) * 2, dtype=np.int32)
        new_mergetable[:] = -1
        new_mergetable[:len(mergetable)] = mergetable
        mergetable = new_mergetable
    while mergetable[lo_class] != -1:
        lo_class = mergetable[lo_class]
    mergetable[hi_class] = lo_class
    return mergetable
                    

@numba.njit
def label_image(decoded_image):
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    mergetable = np.zeros(8, dtype=np.int32)
    mergetable[:] = -1
    for idx in np.ndindex(labels.shape):
        decoded_image_idx = decoded_image[idx]
        labels_idx = labels[idx]
        for aidx in adjacent(idx, labels.shape):
            labels_aidx = labels[aidx]
            if labels_aidx != 0 and decoded_image[aidx] == decoded_image_idx:
                # Already have class for neighboring pixel
                if labels_idx == 0:
                    # This pixel has no class, copy neighbor
                    labels_idx = labels[idx] = labels_aidx
                elif labels_aidx != labels_idx:
                    # This pixel has a contradictory class
                    # Assign minimum and add to merge table
                    mergetable = add_to_merge_table(mergetable, labels_aidx, labels_idx)
                    labels_idx = labels[idx] = min(labels_idx, labels_aidx)
        if labels_idx == 0:
            current_label += 1
            labels[idx] = current_label
    merge_classes(labels, mergetable)
    return labels, current_label
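
To illustrate how the merge table resolves chains, here's a quick hand-check with made-up class numbers:

mt = np.full(8, -1, dtype=np.int32)
mt = add_to_merge_table(mt, 2, 5)  # 5 is folded into 2, so mt[5] = 2
mt = add_to_merge_table(mt, 5, 7)  # follows the chain 5 -> 2, so mt[7] = 2 as well
print(mt)  # [-1 -1 -1 -1 -1  2 -1  2]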

It's about 30% faster than the iterative version on the random arrays I tried.
