我最近写了一个脚本来转换BGR数组的[0,1]浮点数到HSL和回来。我把它放在Code Review上。目前有一个答案,但它不会提高性能。
我已经对我的代码进行了cv2.cvtColor
基准测试,发现我的代码效率很低,所以我想用Numba编译代码,使它运行得更快。
我尝试用@nb.njit(cache=True, fastmath=True)
Package 所有函数,但这不起作用。
所以我测试了我单独使用过的每个NumPy语法和NumPy函数,发现有两个函数不适用于Numba。
我需要找到每个像素的最大通道(np.max(img, axis=-1)
)和每个像素的最小通道(np.max(img, axis=-1)
),axis
参数不适用于Numba。
我试着在谷歌上搜索这个,但我发现的唯一相关的东西是this,但它只实现了np.any
和np.all
,并且只适用于二维数组,而这里的数组是三维的。
我可以写一个基于for循环的解决方案,但我不会写它,因为它注定是低效的,并且首先违背了使用NumPy和Numba的目的。
最小可重现性示例:
import numba as nb
import numpy as np
@nb.njit(cache=True, fastmath=True)
def max_per_cell(arr):
return np.max(arr, axis=-1)
@nb.njit(cache=True, fastmath=True)
def min_per_cell(arr):
return np.min(arr, axis=-1)
img = np.random.random((3, 4, 3))
max_per_cell(img)
min_per_cell(img)
例外情况:
In [2]: max_per_cell(img)
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
Cell In[2], line 1
----> 1 max_per_cell(img)
File C:\Python310\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465 f"by the following argument(s):\n{args_str}\n")
466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
471 error_rewrite(e, 'unsupported_error')
File C:\Python310\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function amax at 0x0000014E306D3370>) found for signature:
>>> amax(array(float64, 3d, C), axis=Literal[int](-1))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'npy_max': File: numba\np\arraymath.py: Line 541.
With argument(s): '(array(float64, 3d, C), axis=int64)':
Rejected as the implementation raised a specific error:
TypingError: got an unexpected keyword argument 'axis'
raised from C:\Python310\lib\site-packages\numba\core\typing\templates.py:784
During: resolving callee type: Function(<function amax at 0x0000014E306D3370>)
During: typing of call at <ipython-input-1-b3894b8b12b8> (10)
File "<ipython-input-1-b3894b8b12b8>", line 10:
def max_per_cell(arr):
return np.max(arr, axis=-1)
^
如何解决这一问题?
2条答案
按热度按时间von4xj4u1#
不使用
np.max()
,而是使用循环来实现它是相当简单的:对它进行基准测试,结果证明它比
np.max(arr, axis=-1)
快16倍。在进行基准测试时,我做了以下假设:
np.max()
的速度更快,只需要15 ms。参见Check if numpy array is contiguous?了解如何判断数组是C顺序还是Fortran顺序。np.random.random((3, 4, 3))
的例子是C连续的。np.max(arr, axis=-1)
进行比较,因为它无法真正优化对NumPy函数的单个调用。yrdbyhpb2#
CHW实现
根据@NickODell回答的评论,这里是一个更快的SIMD友好型解决方案,当图像使用 *CWH布局 * 时(根据@NickODell的要求):
以下是我的机器上的结果(i5- 9600 KF + 40 GiB/s RAM在Windows上):
这意味着这个实现在32位浮点数上快了3倍。前者是标量,后者使用SIMD指令。前者也会因页面错误而变慢,而后者则不会。SIMD版本如果不是内存受限的话,可能会更快。数据类型越小,SIMD实现越快。使用8位无符号整数,SIMD版本大约快9倍,* 仍然内存受限 *:
HWC实现
请注意,只需将
max_per_cell_nb_faster
的单行内容替换为以下内容即可生成@NickODell版本,以支持HWC映像:然而,这个版本比@NickODell的解决方案慢一点(尽管更简单),因为它在内部不使用SIMD指令:
事实上,LLVM优化器AFAIK还不能对这种访问模式进行向量化(因为这非常困难,而且即使手动小心地完成,它的效率通常也低于使用CWH)。