python 如何在JAX中对不等长数组列表上的函数进行矢量化

qyzbxkaa  于 2022-12-10  发布在  Python
关注(0)|答案(1)|浏览(126)

这是我所面临真实的大问题的一个最小的例子。考虑下面的函数:

import jax.numpy as jnp
def test(x):
    return jnp.sum(x)

我试着通过以下方法将其矢量化:

v_test = jax.vmap(test)

test的输入如下所示:

x1 = jnp.array([1,2,3])
x2 = jnp.array([4,5,6,7])
x3 = jnp.array([8,9])
x4 = jnp.array([10])

我对v_test的输入是:

x = [x1, x2, x3, x4]

如果我尝试:

v_test(x)

我得到下面的错误:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
([3, 4, 2, 1],)

有没有办法在一系列长度不等的数组上对test进行矢量化?我可以通过填充来避免这种情况,这样数组就有相同的长度,但是,填充是不需要的。

uttx8gqw

uttx8gqw1#

JAX不支持不规则数组(即每一行的元素数不同的数组),因此目前还没有办法对这类数据使用vmap。最好的办法可能是使用Python for循环:

y = [test(xi) for xi in x]

或者,您也可以用segment_sum或类似的运算来表示您想要的运算。例如:

segments = jnp.concatenate([i * jnp.ones_like(xi) for i, xi in enumerate(x)])
result = jax.ops.segment_sum(jnp.concatenate(x), segments)
print(result)
# [ 6 22 17 10]

另一种可能性是填充输入数组,以便它们可以适合标准的非粗糙2D数组。

相关问题