这是我所面临真实的大问题的一个最小的例子。考虑下面的函数:
import jax.numpy as jnp
def test(x):
return jnp.sum(x)
我试着通过以下方法将其矢量化:
v_test = jax.vmap(test)
test
的输入如下所示:
x1 = jnp.array([1,2,3])
x2 = jnp.array([4,5,6,7])
x3 = jnp.array([8,9])
x4 = jnp.array([10])
我对v_test
的输入是:
x = [x1, x2, x3, x4]
如果我尝试:
v_test(x)
我得到下面的错误:
ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
([3, 4, 2, 1],)
有没有办法在一系列长度不等的数组上对test
进行矢量化?我可以通过填充来避免这种情况,这样数组就有相同的长度,但是,填充是不需要的。
1条答案
按热度按时间uttx8gqw1#
JAX不支持不规则数组(即每一行的元素数不同的数组),因此目前还没有办法对这类数据使用
vmap
。最好的办法可能是使用Pythonfor
循环:或者,您也可以用
segment_sum
或类似的运算来表示您想要的运算。例如:另一种可能性是填充输入数组,以便它们可以适合标准的非粗糙2D数组。