我正在用python写一个函数,使用numba来标记2D或3D数组中的对象,这意味着输入数组中具有相同值的所有正交连接的单元格将在输出数组中被赋予从1到N的唯一标签,其中N是正交连接的组的数量。它与scipy.ndimage.label
等函数以及scikit-image等库中的类似函数非常相似,但这些函数标记了所有正交连接的非零细胞组,因此它会合并具有不同值的连接组,这是我不想要的。例如,给定以下输入:
[0 0 7 7 0 0
0 0 7 0 0 0
0 0 0 0 0 7
0 6 6 0 0 7
0 0 4 4 0 0]
scipy函数将返回
[0 0 1 1 0 0
0 0 1 0 0 0
0 0 0 0 0 3
0 2 2 0 0 3
0 0 2 2 0 0]
请注意,6s和4s合并到标签2
中。我希望将它们标记为单独的组,例如:
[0 0 1 1 0 0
0 0 1 0 0 0
0 0 0 0 0 4
0 2 2 0 0 4
0 0 3 3 0 0]
我asked this about a year ago和一直在使用的解决方案,在接受的答案,但我正在优化我的代码的运行时,并重新审视这个问题。
对于我通常使用的数据大小,链接的解决方案需要大约1 m30 s运行。我写了下面的递归算法,它像普通的python一样运行大约需要30秒,而numba的JIT运行时间为1- 2秒(旁注,我讨厌相邻的函数,任何能让它不那么混乱同时仍然与numba兼容的提示都会受到欢迎):
@numba.njit
def adjacent(idx, shape):
coords = []
if len(shape) > 2:
if idx[0] < shape[0] - 1:
coords.append((idx[0] + 1, idx[1], idx[2]))
if idx[0] > 0:
coords.append((idx[0] - 1, idx[1], idx[2]))
if idx[1] < shape[1] - 1:
coords.append((idx[0], idx[1] + 1, idx[2]))
if idx[1] > 0:
coords.append((idx[0], idx[1] - 1, idx[2]))
if idx[2] < shape[2] - 1:
coords.append((idx[0], idx[1], idx[2] + 1))
if idx[2] > 0:
coords.append((idx[0], idx[1], idx[2] - 1))
else:
if idx[0] < shape[0] - 1:
coords.append((idx[0] + 1, idx[1]))
if idx[0] > 0:
coords.append((idx[0] - 1, idx[1]))
if idx[1] < shape[1] - 1:
coords.append((idx[0], idx[1] + 1))
if idx[1] > 0:
coords.append((idx[0], idx[1] - 1))
return coords
@numba.njit
def apply_label(labels, decoded_image, current_label, idx):
labels[idx] = current_label
for aidx in adjacent(idx, labels.shape):
if decoded_image[aidx] == decoded_image[idx] and labels[aidx] == 0:
apply_label(labels, decoded_image, current_label, aidx)
@numba.njit
def label_image(decoded_image):
labels = np.zeros_like(decoded_image, dtype=np.uint32)
current_label = 0
for idx in zip(*np.where(decoded_image >= 0)):
if labels[idx] == 0:
current_label += 1
apply_label(labels, decoded_image, current_label, idx)
return labels, current_label
这对某些数据有效,但对其他数据崩溃,我发现问题是当有非常大的对象要标记时,达到了递归限制。我试着重写label_image
不使用递归,但现在用numba需要10秒。仍然是一个巨大的改进,从我开始的地方,但它似乎应该有可能获得相同的性能作为递归版本。以下是我的迭代版本:
@numba.njit
def label_image(decoded_image):
labels = np.zeros_like(decoded_image, dtype=np.uint32)
current_label = 0
for idx in zip(*np.where(decoded_image >= 0)):
if labels[idx] == 0:
current_label += 1
idxs = [idx]
while idxs:
cidx = idxs.pop()
if labels[cidx] == 0:
labels[cidx] = current_label
for aidx in adjacent(cidx, labels.shape):
if labels[aidx] == 0 and decoded_image[aidx] == decoded_image[idx]:
idxs.append(aidx)
return labels, current_label
有什么办法可以改善这一点吗?
2条答案
按热度按时间yk9xbfzb1#
这个递归函数能否转化为具有类似性能的迭代函数?
把它变成一个迭代函数很简单,因为它只是一个简单的深度优先搜索(你也可以使用宽度优先搜索,使用队列而不是堆栈,两者都可以工作)。只需使用堆栈来跟踪要访问的节点。这里有一个通用的解决方案,适用于任何数量的维度:
从元组的第i个组件中增加或减去一是很尴尬的(我在这里介绍一个临时列表),numba不接受它(类型错误)。一个简单的解决方案是显式地编写2d和3d版本,这可能会极大地提高性能:
还请注意,迭代解决方案不受堆栈限制:
np.full((100,100,100), 1)
在迭代解决方案中工作正常,但在递归解决方案中失败(如果使用numba,则会出现segfaults)。做一个非常基本的基准测试,
(many迭代以最大限度地减少JIT的影响,也可以进行一些预热运行,然后开始测量时间或类似操作)
迭代解决方案似乎快了好几倍(在我的机器上大约快了5倍,见下文)。你可能会优化递归解决方案,并使其达到相当的速度。通过避免临时
coords
列表或通过将np.where
改变为> 0
。我不知道numba如何优化压缩的
np.where
。为了进一步优化,您可以考虑(和基准测试)使用显式嵌套for x in range(0, w): for y in range(0, h):
循环。为了与Nick提出的合并策略保持竞争力,我进一步优化了这个策略,选择了一些容易实现的目标:
zip
转换为continue
而不是np.where
的显式循环。decoded_image[idx]
存储在一个局部变量中(理想情况下应该没关系,但也没什么坏处)。w*h
或w*h*l
)。然后,我拼凑了一个基准来比较这四种方法(原始递归,旧迭代,新迭代,基于合并),将它们放在四个不同的模块中:
这将输出类似于
在我看来,改进的迭代方法可能更好。请注意,原始的基本基准测试似乎是递归变体的最差情况。总的来说,差别并不大。
测试的数组非常小(20³)。如果我用一个更大的数组(100³)和一个更小的n(20)进行测试,我大致得到以下结果(
rec
被省略,因为由于堆栈限制,它会segfault):迭代方法似乎仍然更有效。
hgc7kmma2#
这是我的尝试。我的想法是应用逐行连接组件算法。这个想法是,而不是试图洪水填充来识别组件的所有成员,你开始在每一行标记,并注意每次你到达一个矛盾。
然后,你把你的矛盾列表,并将它们合并到同一个类中。
This video解释了算法。
这样做的缺点是需要对输入数据进行两次处理。然而,它的优点是它具有非常好的缓存一致性。换句话说,它按顺序访问数组中的所有数据。
这与
scipy.ndimage.label()
内部遵循的算法相同,但我们可以更快,因为我们不需要为每个类调用一次。下面是我使用的代码。
它比我在随机数组上尝试的迭代版本快了大约30%。