jax批处理

gywdnpxw  于 2021-08-25  发布在  Java
关注(0)|答案(0)|浏览(241)

我有一个函数 compute(x) 哪里 x 是一个 jnp.ndarray . 现在,我想使用 vmap 将其转换为接受一批数组的函数 x[i] ,然后 jit 加快速度。 compute(x) 有点像:

def compute(x):
    # ... some code
    y = very_expensive_function(x)
    return y

但是,每个阵列 x[i] 有不同的长度。我可以很容易地解决这个问题,用尾随的零填充数组,使它们都具有相同的长度 Nvmap(compute) 可应用于具有形状的批次 (batch_size, N) .
然而,这样做会导致 very_expensive_function() 也可在每个数组的尾部零上调用 x[i] . 有办法修改吗 compute() 以致 very_expensive_function() 仅在一个片段上调用 x ,不妨碍 vmapjit ?

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题