我有一个大小为[n_rows, n_cols, n_channels]
的numpy数组。在我的代码中,我有一个循环,其中数组不断更新和裁剪:
def update(arr, row_idx, col_idx, ch_idx):
arr[row_idx, col_idx, ch_idx] += 1
arr[arr > 10] = 10
arr = np.array(n_rows, n_cols, n_channels)
while True:
update(arr, 0, 1, 2)
字符串
为了优化我的代码,我可以使用带索引列表的缓存,每N次迭代更新一次数组:
def update(arr, rows_list, cols_list, ch_list):
arr[rows_list, cols_list, ch_list] += 1
arr[arr > 10] = 10
arr = np.array(n_rows, n_cols, n_channels)
cache_length = 3
rows_list, cols_list, ch_list = [], [], []
while True:
rows_list.append(something1)
cols_list.append(something2)
ch_list.append(something3)
if len(row_list) == cache_length:
update(arr, rows_list, cols_list, ch_list)
rows_list, cols_list, ch_list = [], [], []
型
这可以节省时间,但可能会发生该高速缓存多次包含相同的数组索引,例如:
# arr[0, 0, 6] should be updated twice
update(arr, [0, 0, 2], [3, 3, 5], [6, 6, 6])
型
如何更改代码以使此优化工作?
1条答案
按热度按时间hrysbysz1#
您可以使用
numpy.unique
进行聚合:字符串
你可以通过只裁剪更新的值(而不是整个数组)来进一步优化:
型
范例:
型