我正在尝试写一个函数,检查在一个二维数组中,在行、列和对角线方向上,是否至少有一个数字是“一行五个”。要检查的数组可以是大小>= 5的任何方阵,但我最有可能使用的是7 x7的。
例如,下面的矩阵在一个序列中有3次出现五个1的模式(至少检测一个就足够了)。一个在第一列,一个对角地从(0,6)到(5,1),另一个对角地从(1,0)到(5,6)。
A = np.array(
[
[0, 1, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 0],
[1, 0, 0, 1, 0, 0],
[1, 0, 1 ,0, 1, 0],
[1, 1, 0, 0, 0, 1],
[1, 0, 0, 0, 0, 0]
]
)
字符串
我在这方面的尝试如下所示
import numpy as np
import cProfile
import pstats
import time
class Board:
def __init__(self, size):
self.data = np.zeros((size, size), dtype=np.byte)
self.size = size
# make sliding window of size 5 for each row and col
self.rowWindows = np.lib.stride_tricks.sliding_window_view(self.data, window_shape=(1,5))
self.colWindows = np.lib.stride_tricks.sliding_window_view(np.transpose(self.data), window_shape=(1,5))
# make sliding window for both diagonal directions
# storing them as object array since the diagonals windows have different sizes
self.antiDiagonalWindow = np.array(
[
np.lib.stride_tricks.sliding_window_view(np.fliplr(self.data).diagonal(offset=i), window_shape=(5,))
for i in range(-self.size + 5, self.size - 5 + 1, 1)
],
dtype=object,
)
self.diagonalWindow = np.array(
[
np.lib.stride_tricks.sliding_window_view(self.data.diagonal(offset=i), window_shape=(5,))
for i in range(-self.size + 5, self.size - 5 + 1, 1)
],
dtype=object,
)
def hasFiveInRow(self, value):
return (
np.any(np.all(self.rowWindows == value, -1),)
or np.any(np.all(self.colWindows == value,-1), )
# have to use concat to turn sliding window views into 2d array
# since diagonals have different sizes
or np.any(np.all(np.concatenate(self.antiDiagonalWindow) == value, -1), )
or np.any(np.all(np.concatenate(self.diagonalWindow) == value, -1), )
)
def benchMark():
b = Board(size=7)
b.data[:]=np.random.randint(low=0, high=3, size=(7,7))
for i in range(100_000):
val = b.hasFiveInRow(1)
# t0 = time.time()
# benchMark()
# print(time.time() - t0)
with cProfile.Profile() as p:
benchMark()
res = pstats.Stats(p)
res.sort_stats(pstats.SortKey.TIME)
res.print_stats()
型
结果性能不是太差,但我想提高它,如果可能的话,因为我使用它作为一个游戏ai树搜索的一部分,将不得不调用函数非常大量的次数。我认为np.any(np.all(windows))
是不理想的,因为它必须创建许多布尔数组减少到一个单一的值。
cProfile日志显示了大量对'reduce'、'dictcomp'和_wrapreduction'等的调用,这些调用需要很长时间才能完成。
有没有更好的方法来寻找这个模式呢?我只需要检查这个模式是否以布尔值的形式至少出现过一次,尽管得到确切的位置和出现的次数会很好。
任何帮助将不胜感激!
2条答案
按热度按时间uujelgoq1#
我认为这是
numba
闪耀的场景:字符串
快速基准:
型
在我的计算机上打印(AMD 5700 x,Python 3.11):
型
所以~ 60倍加速。
1mrurvl12#
效率是关键--循环经常击败花哨的函数。我将“五行”的逻辑修改为一个简单的嵌套迭代。
它不再使用滑动窗口,而是手动检查每个单元格与其八个相邻单元格之间的关系,没有冗余。
将模式检查 Package 在一个干净的函数中,以保持主循环整洁。
结果证明了这一点--原来的几分钟现在变成了几秒钟。小的调整可以带来大的性能提升。
字符串