numpy 如果至少选择了一个元素,则选择子集的所有元素(JAX)

f3temu5u  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(83)

我遇到了以下问题,我无法设法用JAX编写一个可抖动且高效的解决方案。
我有一组元素。其中一些元素被包括在内(基于一个条件,现在这个条件并不重要)。包含的元素用1表示,不包含的元素用0表示。例如,数组arr = jnp.array([1, 0, 0, 0, 0, 0])表示我有6个元素,其中第一个元素是基于我的条件包含的。
这些元素被分组为子集。我有第二个数组,它指示每个子集在第一个数组arr中的开始位置。例如,数组subsets = jnp.array([0, 2])表示第一个子集从位置0开始,第二个子集从位置2开始。
现在,如果包含一个基于arr的元素,我希望将所有元素都包含在同一个子集中。在本例中,输出应为[1, 1, 0, 0, 0, 0]
我试过用jax.lax.fori_loop,但它很慢。

@jax.jit                                                                        
def select_subsets(arr, subsets):                 
    new_arr = arr.copy()            
    n_resid = subsets.shape[0]    
    indices = jnp.arange(arr.shape[0])       
                                                                                    
    def func(i, new_arr):    
        start = subsets[i]    
        stop = subsets[i+1]    
        arr_sliced = jnp.where((indices >= start) & (indices < stop), arr, 0.0)    
        sum_ = jnp.sum(arr_sliced)    
        new_arr = jnp.where(sum_ > 0.5, jnp.where((indices >= start) & (indices < stop), 1, new_arr), new_arr)    
        return new_arr                          
                                                 
    new_arr = jax.lax.fori_loop(0, n_resid-1, func, new_arr)    
    return new_arr

字符串
如果我使用subsets,最后一个元素等于arrsubsets = jnp.array([0, 2, 6])中的元素数,则此函数有效。
然后我想写一个矢量化的版本(使用jax.numpy操作),但我不能做到这一点。
有没有一个JAX大师可以帮我解决这个问题?
多谢了!

ubbxdtey

ubbxdtey1#

这是一个矢量化的版本。它示例化了一个形状为len(subsets) x len(arr)的遮罩,根据这些值的大小,这可能是不需要的。

@jax.jit                                                                             
def vectorized_select_subsets(arr, subsets):                         
    l, = arr.shape               
                               
    indices = jnp.arange(l)[None, :]
                
    # Broadcast to mask of shape (n_subsets, input_length)
    subset_masks = (
        (indices >= subsets[:-1, None])
        & (indices < subsets[1:, None])        
    )                                                           
                                                    
    # Shape (n_subsets,) array indicating whether each subset is included
    include_subset = jnp.any(subset_masks & arr[None, :], axis=1)          
                                                          
    # Reduce down columns 
    result = jnp.any(subset_masks & include_subset[:, None], axis=0).astype(jnp.int32)   
    return result

字符串
我在长度为512和32个子集的数组上对基于循环的版本进行了计时:

Loop: 6254.647 it/s
Vectorized: 37940.335 it/s

相关问题