numpy 麻木对角线函数速度慢

toe95027  于 2022-11-10  发布在  其他
关注(0)|答案(1)|浏览(182)

我想实现连接4游戏作为一个爱好项目,我不知道,为什么在对角线上的匹配搜索是如此缓慢。在使用psstats分析我的代码时,我发现这是瓶颈。我想要建立一个电脑敌人,可以分析游戏中数千个未来的步骤,因此性能是一个问题。
有没有人知道,如何通过下面的代码来提高性能?我选择了NumPy来做这件事,因为我认为这会加快速度。问题是,我找不到一种方法来避免for循环。

import numpy as np

# Finds all the diagonal and off-diagonal-sequences in a 7x6 numpy array

def findseq(sm,seq=2,redyellow=1):
    matches=0
    # search in the diagonals
    # diags stores all the diagonals and off diagonals as rows of a matrix
    diags=np.zeros((1,6),dtype=np.int8)
    for k in range(-5,7):   
        t=np.zeros(6,dtype=np.int8)
        a=np.diag(sm,k=k).copy()
        t[:len(a)] += a
        s=np.zeros(6,dtype=np.int8)
        a=np.diag(np.fliplr(sm),k=k).copy()
        s[:len(a)] += a
        diags=np.concatenate(( diags,t[None,:],s[None,:]),axis=0)
    diags=np.delete(diags,0,0)
    # print(diags)
    # now, search for sequences
    Na=np.size(diags,axis=1)
    n=np.arange(Na-seq+1)[:,None]+np.arange(seq)
    seqmat=np.all(diags[:,n]==redyellow,axis=2)
    matches+=seqmat.sum()

    return matches

def randomdebug():
    # sm=np.array([[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,2,1,1,0,0]])
    sm=np.random.randint(0,3,size=(6,7))
    return sm

# in my main program, I need to do this thousands of times

matches=[]
for i in range(1000):
    sm=randomdebug()
    matches.append(findseq(sm,seq=3,redyellow=1))
    matches.append(findseq(sm,seq=3,redyellow=2))
    # print(sm)
    # print(findseq(sm,seq=3))

以下是psstats

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2000    1.965    0.001    4.887    0.002 Frage zu diag.py:4(findseq)
151002/103002    0.722    0.000    1.979    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
    48000    0.264    0.000    0.264    0.000 {method 'diagonal' of 'numpy.ndarray' objects}
    48072    0.251    0.000    0.251    0.000 {method 'copy' of 'numpy.ndarray' objects}
    48000    0.209    0.000    0.985    0.000 twodim_base.py:240(diag)
    48000    0.179    0.000    1.334    0.000 <__array_function__ internals>:177(diag)
    50000    0.165    0.000    0.165    0.000 {built-in method numpy.zeros}

我刚接触巨蟒,所以请想象一个“无望的新手”的标签;-)

8ulbf1ek

8ulbf1ek1#

正如Andrey在评论中所述,代码调用了大量需要额外内存分配的NP函数。我相信这就是瓶颈。
我建议预测所有对角线的指数,因为它们在你的情况下不会有太大变化(矩阵形状保持不变,我猜序列可能会改变)。然后,您可以使用它们快速寻址对角线:

import numpy as np

known_diagonals = dict()
def diagonal_indices(h: int, w: int, length: int = 3) -> np.array:
    '''
    Returns array (shape diagonal_count x length) of diagonal indices
    of a flatten array
    '''
    # one of many ways to store precomputed function output
    # cleaner way would probably be to do this outside this function
    diagonal_indices_key = (h, w, length)
    if diagonal_indices_key in known_diagonals:
        return known_diagonals[diagonal_indices_key]

    diagonals_count = (h + 1 - length) * (w + 1 - length) * 2

    # default value is meant to ease process with cumsum:
    # adding h + 1 selects an index 1 down and 1 right, h - 1 index 1 down 1 left
    # firts half dedicated to right down diagonals
    diagonals = np.full((diagonals_count, length), w + 1, dtype=np.int32)
    # second half dedicated to left down diagonals
    diagonals[diagonals_count//2::] = w - 1

    # this could have been calculated mathematicaly
    flat_indices = np.arange(w * h).reshape((h, w))
    # print(flat_indices)

    # selects rectangle offseted by l - 1 from right and down edges
    diagonal_starts_rd = flat_indices[:h + 1 - length, :w + 1 - length]
    # selects rectangle offseted by l - 1 from left and down edges
    diagonal_starts_ld = flat_indices[:h + 1 - length, -(w + 1 - length):]

    # sets starts
    diagonals[:diagonals_count//2, 0] = diagonal_starts_rd.flatten()
    diagonals[diagonals_count//2::, 0] = diagonal_starts_ld.flatten()

    # sum triplets left to right
    # diagonals contains triplets (or vector of other length) of (start, h+-1, h+-1). cumsum makes diagonal indices
    diagonals = diagonals.cumsum(axis=1)

    # save ouput
    known_diagonals[diagonal_indices_key] = diagonals

    return diagonals

# Finds all the diagonal and off-diagonal-sequences in a 7x6 numpy array

def findseq(sm: np.array, seq: int = 2, redyellow: int = 1) -> int:
    matches = 0
    diagonals = diagonal_indices(*sm.shape, seq)

    seqmat = np.all(sm.flatten()[diagonals] == redyellow, axis=1)
    matches += seqmat.sum()

    return matches

def randomdebug():
    # sm=np.array([[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,2,1,1,0,0]])
    sm=np.random.randint(0,3,size=(6,7))
    return sm

# in my main program, I need to do this thousands of times

matches=[]
for i in range(1000):
    sm=randomdebug()
    matches.append(findseq(sm,seq=3,redyellow=1))
    matches.append(findseq(sm,seq=3,redyellow=2))
    # print(sm)
    # print(findseq(sm,seq=3))

相关问题