我遇到了以下问题,我无法设法用JAX编写一个可抖动且高效的解决方案。
我有一组元素。其中一些元素被包括在内(基于一个条件,现在这个条件并不重要)。包含的元素用1表示,不包含的元素用0表示。例如,数组arr = jnp.array([1, 0, 0, 0, 0, 0])
表示我有6个元素,其中第一个元素是基于我的条件包含的。
这些元素被分组为子集。我有第二个数组,它指示每个子集在第一个数组arr
中的开始位置。例如,数组subsets = jnp.array([0, 2])
表示第一个子集从位置0开始,第二个子集从位置2开始。
现在,如果包含一个基于arr
的元素,我希望将所有元素都包含在同一个子集中。在本例中,输出应为[1, 1, 0, 0, 0, 0]
。
我试过用jax.lax.fori_loop
,但它很慢。
@jax.jit
def select_subsets(arr, subsets):
new_arr = arr.copy()
n_resid = subsets.shape[0]
indices = jnp.arange(arr.shape[0])
def func(i, new_arr):
start = subsets[i]
stop = subsets[i+1]
arr_sliced = jnp.where((indices >= start) & (indices < stop), arr, 0.0)
sum_ = jnp.sum(arr_sliced)
new_arr = jnp.where(sum_ > 0.5, jnp.where((indices >= start) & (indices < stop), 1, new_arr), new_arr)
return new_arr
new_arr = jax.lax.fori_loop(0, n_resid-1, func, new_arr)
return new_arr
字符串
如果我使用subsets
,最后一个元素等于arr
,subsets = jnp.array([0, 2, 6])
中的元素数,则此函数有效。
然后我想写一个矢量化的版本(使用jax.numpy
操作),但我不能做到这一点。
有没有一个JAX大师可以帮我解决这个问题?
多谢了!
1条答案
按热度按时间ubbxdtey1#
这是一个矢量化的版本。它示例化了一个形状为
len(subsets) x len(arr)
的遮罩,根据这些值的大小,这可能是不需要的。字符串
我在长度为512和32个子集的数组上对基于循环的版本进行了计时:
型