如何用numpy向量化一个函数,以便它可以应用于3d数组,因为这个函数需要访问数组的某些单元格?

6kkfgxo0  于 2023-10-19  发布在  其他
关注(0)|答案(2)|浏览(117)

我有一个计算,在这个计算中,我需要遍历一个3d numpy数组的元素,并将它们添加到数组的第二维中的值(跳过该维中的值)。它类似于这个典型的模仿生殖的例子:

import numpy as np

data = np.array([
    [[1, 1, 1], [10, 10, 10], [1, 1, 1]],
    [[2, 2, 2], [20, 20, 20], [2, 2, 2]],
    [[3, 3, 3], [30, 30, 30], [3, 3, 3]] ])

def process_data(const_idx, data, i, j, k):
    if const_idx != j:
        # PROBLEM: how can I access this value if this function is vectorized?
        value_to_add = data[i][const_idx][k]
        data[i][j][k] += value_to_add

const_idx = 1
for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        for k in range(data.shape[2]):
            process_data(const_idx, data, i, j, k)

print(data)

在这种情况下,预期输出为:

[[[11 11 11]
  [10 10 10]
  [11 11 11]]

 [[22 22 22]
  [20 20 20]
  [22 22 22]]

 [[33 33 33]
  [30 30 30]
  [33 33 33]]]

上面的代码可以工作,但对于大型数组来说非常慢。我想把这个函数向量化。
我的第一个刺是这样的:

def process_data(val, data, const_idx):
    # PROBLEM: How can I access this value given that I do not have access to the i / j / k coordinates val came from?
    value_to_add = ...
    
    # PROBLEM: I cannot make this check either since I dont know the j index being processed here
    if const_idx != j:
        return val + value_to_add
    else:
        return val

vfunc = np.vectorize(process_data)

result = vfunc(data, data, const_idx)

print(result)

我如何才能做到这一点,或者矢量化可能不是答案?

pbossiut

pbossiut1#

const_idx指向作为加法因子的行的索引。
您可以使用以下方法在所需尺寸上快速执行在位添加:

def add_by_idx(arr, idx):
    r = np.arange(arr.shape[1])  # row indices
    arr[:, r[r != idx], :] += arr[:, [idx], :]

add_by_idx(data, 1)
print(data)
[[[11 11 11]
  [10 10 10]
  [11 11 11]]

 [[22 22 22]
  [20 20 20]
  [22 22 22]]

 [[33 33 33]
  [30 30 30]
  [33 33 33]]]
ufj5ltwl

ufj5ltwl2#

这里是另一种方法:在所有地方添加行,然后从索引中减去它:

def process_data(data, const_idx):
    to_add = data[:, const_idx, :].copy()
    data += to_add[:, None, :]
    data[:, const_idx] -= to_add

process_data(data, const_idx=1)
print(data)

输出量:

[[[11 11 11]
  [10 10 10]
  [11 11 11]]

 [[22 22 22]
  [20 20 20]
  [22 22 22]]

 [[33 33 33]
  [30 30 30]
  [33 33 33]]]

相关问题