按条件就地分区numpy数组

lhcgjxsq  于 2023-04-12  发布在  其他
关注(0)|答案(4)|浏览(90)

我有一个1d数组的u64整数。我需要分区的基础上给定的位到位。
在纯python中,两个指针问题很容易解决(类似于quicksort中的分区),但我想知道是否可以使用numpy有效地解决这个问题。
假设我有:

arr = np.arange(720, dtype=np.uint64)  # a lot of 64bit unsigned ints
np.random.shuffle(arr)  # in unknown order
left = arr[arr & 32 == 0]  # all elements having 5th bit 0 in any order
right = arr[arr & 32 == 32]  # all elements having 5th bit 1 in any order
arr[:] = np.concatenate([left, right])

我还需要知道分区的索引(也就是上面的len(left))。

bvhaajcl

bvhaajcl1#

你可以使用argsort来为数组分配重新排序的索引(我使用最后一位而不是&32,以使结果更容易理解evens/odds):

import numpy as np

arr = np.arange(20)  # a lot of 64bit unsigned ints
np.random.shuffle(arr)  # in unknown order

print(arr)
# [ 2  6 15 10 12  5 11  0 14 18 19  3  7  4  1 17  9  8 16 13]
arr[:] = arr[np.argsort(arr&1)]
print(arr)
# [ 2  8  4 16 14  0 18 12 10  6  5 19  3  7 15  1 17  9 11 13]

如果不进行排序,您可以设置左右位置的掩码,通过计算掩码中的True值来测量左分区的大小,然后使用掩码和逆掩码部分分配下标。

mask = arr&1 == 0
left = np.sum(mask)
arr[:left],arr[left:] = arr[mask],arr[~mask].copy()
  • 请注意,第二部分必须使用.copy(),因为在第二次赋值之前,arr的内容已经发生了变化。*
wtlkbnrh

wtlkbnrh2#

一个可能的解决方案可以是np.where(),以获得满足条件的元素的索引。

import numpy as np

arr = np.arange(720, dtype=np.uint64)
np.random.shuffle(arr)

mask = arr & 32 == 0
left_indices = np.where(mask)[0]
right_indices = np.where(~mask)[0]

# Use the indices to partition the array
left = arr[left_indices]
right = arr[right_indices]
arr[:] = np.concatenate([left, right])

# Get the length of the left partition
partition_index = len(left_indices)
print("Partition index:", partition_index)
cidc1ykv

cidc1ykv3#

可以使用np.argsort

result = arr[np.argsort(arr & 32)]

排序arr & 32是对两个值进行排序:0和32,这是您需要的分区。如果您需要分区索引,

(result & 32).argmax()

这将返回数组的最大值的第一个索引,即排序数据中的零的数量。

blpfk2vs

blpfk2vs4#

社区中有很多很酷的解决方案(谢谢大家!),但看起来使用纯numpy不可能保持O(1)空间。
我决定实现自己的分区例程并使用numba jit来接近numpy性能。

import numpy as np
from numba import njit

@njit
def partition_by_bit(nparr, bit_to_partition_on):
    n = nparr.shape[0]
    i = 0
    j = n - 1
    mask = 1 << bit_to_partition_on

    while True:
        while i < n and (nparr[i] & mask == 0):
            i += 1
        while j >= 0 and (nparr[j] & mask == mask):
            j -= 1
        if i >= j:
            break
        nparr[i], nparr[j] = nparr[j], nparr[i]

    return i

arr = np.arange(720, dtype=np.uint64)
np.random.shuffle(arr)

partition_index = partition_by_bit(arr, 5)  # O(1) space, O(n) time

相关问题