我试图创建一个矩阵的元素的分箱和,这样属于每个bin的元素被加在一起。
在Python方面,我有三个矩阵:一个1d数组,包含我想要添加的值,我称之为values;一个包含索引的2D数组,其中需要添加值矩阵的第n个值,我称之为index_mat*;以及一个我想要添加值的2D矩阵,我称之为sol**。
就Pythton而言,我想用for循环做的一个例子是:
import jax.numpy as jnp
import numpy as np
values = np.random.random((1000,2)) * 100
bins = jnp.linspace(0,100,1000)
indices = jnp.digitize(values, bins)
sol = jnp.zeros((len(bins), len(bins), 2))
for i in range(len(indices)):
sol = sol.at[indices[i][0],indices[i][1]].add(values[i])
我想找到的是在JAX框架内实现元素总和的最佳方法,而这些元素可以以这种非最佳方式完成。
一般来说,我打算把它扩展到比(1000,1000,2)更多的维度,我希望找到一个更快的解决方案,希望不需要for循环。
1条答案
按热度按时间izkcnapc1#
您可以通过广播索引更新更有效地计算相同的内容: