python 如何在Jax中实现动态范围内的vmappable和?

r3i60tvu  于 2023-05-27  发布在  Python
关注(0)|答案(1)|浏览(117)

我想在Jax中实现类似下面的Python函数,并使用对vmap的调用 Package 它。我希望它是完全反向模式微分(相对于x)使用grad(),即使在vmap之后。

def f(x,kmax):
  return sum ([x**k for k in range(1,kmax+1)])

(This是函数的一个有意简化的版本;我意识到在这种情况下我可以使用几何级数的封闭形式表达式;遗憾的是,我试图实现的实际函数没有我所知道的闭合形式的总和。)
有什么办法可以做到这一点吗?好像有 * 必须 *;但是如果kmax是动态的,则fori_loop不是反向模式可微的,jax.lax.scan需要一个静态形状的数组,否则它将抛出ConcretizationTypeError s,类似地,像range这样的Python原语(如上所述)如果 Package 在vmap中,则抛出TracerIntegerConversionError
我想我理解需要固定形状的数组的限制,但是我用过的每个autodiff框架都允许你动态地 * 以某种方式 * 构造任意大小的表达式。在一个变化的整数范围内求和是一个非常基本的数学工具。如何在Jax中实现这一点?
编辑以重新聚焦问题定义(问题更多的是vmap而不是grad),并提供以下示例。
具体来说,这是我希望能够做到的

import jax

def f(x,kmax):
  return sum ([x**k for k in range(1,kmax+1)])

fmap = jax.vmap(f,in_axes=(None,-1))

x = 3.
kmaxes = jax.numpy.array([1,2,3])

print(fmap(x,kmaxes))

fmap_sum = lambda k,kmaxes:jax.numpy.sum(fmap(k,kmaxes))

print(fmap_sum(x,kmaxes))
print(jax.grad(fmap_sum)(x,kmaxes))

这将在range(1,kmax+1)处引发TracerIntegerConversionError。我想让它做的事情是这样的:

import jax

def f(x,kmax):
  return sum ([x**k for k in range(1,kmax+1)])

def fmap(x,kmaxes):
  return [f(x,kmax) for kmax in kmaxes]

x = 3.
kmaxes = jax.numpy.array([1,2,3])

print(fmap(x,kmaxes))

def fmap_sum(x,kmaxes):
  return sum(fmap(x,kmaxes))

print(fmap_sum(x,kmaxes))
print(jax.grad(fmap_sum)(x,kmaxes))

这给出了正确的结果(但是失去了VMAP的并行化和加速)。

dauxcl2d

dauxcl2d1#

首先,要使您的函数与vmap兼容,您需要将Python控制流替换为jax.lax控制流操作。在这种情况下,lax.fori_loop似乎适用:

def f1(x, k):
  def body_fun(i, val):
    return val + x ** i
  return jax.lax.fori_loop(1, k + 1, body_fun, jnp.zeros_like(x))

f1map = jax.vmap(f1, (None, 0))
print(f1map(x, kmaxes))
# [ 3. 12. 39.]

但是由于循环的大小是动态的,这与反向模式autodiff不兼容:

jax.jacrev(f1map)(x, kmaxes)
# ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

要解决这个问题,可以修改函数,使其使用静态循环大小。这里有一种方法可以做到这一点:

def f2(x, k, kmax):  # kmax should be static
  def body_fun(i, val):
    return val + jnp.where(i <= k, x ** i, 0)
  return jax.lax.fori_loop(1, kmax + 1, body_fun, jnp.zeros_like(x))

f2map = jax.vmap(f2, (None, 0, None))

print(f2map(x, kmaxes, kmaxes.max()))  # compatible with vmap
# [ 3. 12. 39.]

print(jax.jacrev(f2map)(x, kmaxes, kmaxes.max()))  # and with reverse-mode autodiff
# [ 1.  7. 34.]

相关问题