在NumPy中高效迭代,其中下一次迭代取决于上一次迭代

wr98u20j  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(109)

我正在尝试使用NumPy模拟任意精度的二进制算术。举一个简单的例子,我有一段代码(在基本的非NumPy Python中),它将一个二进制数加一,其中二进制数表示为从最低到最高有效位的0和1的列表(因此它们从右到左读取为二进制数):

def increment_bits(bits):
    """
    Returns binary representation corresponding to bits + 1
    where bits is a list of 0s and 1s.
    
    >>> increment_bits([1, 1, 1, 0, 1])   # 23 + 1 == 24
    [0, 0, 0, 1, 1]
    >>> increment_bits([1, 1, 1])         #  7 + 1 ==  8   <- an extra bit now needed
    [0, 0, 0, 1]
    """
    new_bits = bits[:]
    for i, v in enumerate(new_bits):
        if v:
            new_bits[i] = 0
        else:
            new_bits[i] = 1
            return new_bits
    # if we have made it here, then there is a final "carry"
    new_bits.append(1)
    return new_bits

字符串
我的目标是使用NumPy实现相同的效果,但速度更快,但我是NumPy的新手,不确定迭代NumPy数组以实现所需效果的最快(或最NumPy)方法。

vx6bjr1n

vx6bjr1n1#

我同意评论者的看法,Numpy数组不是一个好方法,但我还是想给予一下。主要思想是巧妙地创建一个掩码,通过使用累积乘积来添加一个。这将创建一个1的数组,直到第一个0。然后我们需要移位1,然后将第0个索引设置为1,所以我们在每个索引上加1,直到第一个0之后。
希望numpy可能有一个比O(n)更快的累积乘积实现,但结果仍然比迭代列表慢。此外,转换成numpy数组总是有开销的。

def numpy_bit(bits):
    mask = np.cumprod(bits)
    mask = np.roll(mask, 1)
    mask[0] = 1

    res = (bits + np.ones(len(bits)) * mask) % 2
    if np.all(res == 0):
        return np.append(res, 1)
    else:
        return res

字符串

相关问题