使用索引列表更新一个numpy数组,该索引列表具有多次出现的相同索引

cbeh67ev  于 11个月前  发布在  其他
关注(0)|答案(1)|浏览(102)

我有一个大小为[n_rows, n_cols, n_channels]的numpy数组。在我的代码中,我有一个循环,其中数组不断更新和裁剪:

def update(arr, row_idx, col_idx, ch_idx):
    arr[row_idx, col_idx, ch_idx] += 1
    arr[arr > 10] = 10

arr = np.array(n_rows, n_cols, n_channels)
while True:
    update(arr, 0, 1, 2)

字符串
为了优化我的代码,我可以使用带索引列表的缓存,每N次迭代更新一次数组:

def update(arr, rows_list, cols_list, ch_list):
        arr[rows_list, cols_list, ch_list] += 1
        arr[arr > 10] = 10

arr = np.array(n_rows, n_cols, n_channels)
cache_length = 3
rows_list, cols_list, ch_list = [], [], []
while True:
    rows_list.append(something1)
    cols_list.append(something2)
    ch_list.append(something3)
    if len(row_list) == cache_length:
        update(arr, rows_list, cols_list, ch_list)
        rows_list, cols_list, ch_list = [], [], []


这可以节省时间,但可能会发生该高速缓存多次包含相同的数组索引,例如:

# arr[0, 0, 6] should be updated twice
update(arr, [0, 0, 2], [3, 3, 5], [6, 6, 6])


如何更改代码以使此优化工作?

hrysbysz

hrysbysz1#

您可以使用numpy.unique进行聚合:

def update(arr, row_idx, col_idx, ch_idx):
    idx, cnt = np.unique([row_idx, col_idx, ch_idx],
                         return_counts=True, axis=1)
    arr[tuple(idx)] += cnt
    arr[arr > 10] = 10

字符串
你可以通过只裁剪更新的值(而不是整个数组)来进一步优化:

def update(arr, row_idx, col_idx, ch_idx):
    idx, cnt = np.unique([row_idx, col_idx, ch_idx],
                         return_counts=True, axis=1)
    idx = tuple(idx)
    arr[idx] = np.clip(arr[idx]+cnt, -np.inf, 10)


范例:

arr = np.zeros((2, 3, 4), dtype='int')
update(arr, [0, 0, 1], [1, 1, 2], [3, 3, 3])

# arr
array([[[0, 0, 0, 0],
        [0, 0, 0, 2],
        [0, 0, 0, 0]],

       [[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 1]]])

相关问题