numpy 如何在JAX中有效地将元素从另一个矩阵的索引处添加到另一个矩阵?

9rbhqvlz  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(133)

我试图创建一个矩阵的元素的分箱和,这样属于每个bin的元素被加在一起。
在Python方面,我有三个矩阵:一个1d数组,包含我想要添加的值,我称之为values;一个包含索引的2D数组,其中需要添加值矩阵的第n个值,我称之为index_mat*;以及一个我想要添加值的2D矩阵,我称之为sol**。
就Pythton而言,我想用for循环做的一个例子是:

import jax.numpy as jnp
import numpy as np

values = np.random.random((1000,2)) * 100

bins = jnp.linspace(0,100,1000)
indices = jnp.digitize(values, bins)

sol = jnp.zeros((len(bins), len(bins), 2))

for i in range(len(indices)):
   sol = sol.at[indices[i][0],indices[i][1]].add(values[i])

我想找到的是在JAX框架内实现元素总和的最佳方法,而这些元素可以以这种非最佳方式完成。
一般来说,我打算把它扩展到比(1000,1000,2)更多的维度,我希望找到一个更快的解决方案,希望不需要for循环。

izkcnapc

izkcnapc1#

您可以通过广播索引更新更有效地计算相同的内容:

sol = jnp.zeros((len(bins), len(bins), 2))
sol = sol.at[indices[:, 0], indices[:, 1]].add(values)

相关问题