我想在Jax中实现类似下面的Python函数,并使用对vmap
的调用 Package 它。我希望它是完全反向模式微分(相对于x
)使用grad()
,即使在vmap之后。
def f(x,kmax):
return sum ([x**k for k in range(1,kmax+1)])
(This是函数的一个有意简化的版本;我意识到在这种情况下我可以使用几何级数的封闭形式表达式;遗憾的是,我试图实现的实际函数没有我所知道的闭合形式的总和。)
有什么办法可以做到这一点吗?好像有 * 必须 *;但是如果kmax
是动态的,则fori_loop
不是反向模式可微的,jax.lax.scan
需要一个静态形状的数组,否则它将抛出ConcretizationTypeError
s,类似地,像range
这样的Python原语(如上所述)如果 Package 在vmap
中,则抛出TracerIntegerConversionError
。
我想我理解需要固定形状的数组的限制,但是我用过的每个autodiff框架都允许你动态地 * 以某种方式 * 构造任意大小的表达式。在一个变化的整数范围内求和是一个非常基本的数学工具。如何在Jax中实现这一点?
编辑以重新聚焦问题定义(问题更多的是vmap而不是grad),并提供以下示例。
具体来说,这是我希望能够做到的
import jax
def f(x,kmax):
return sum ([x**k for k in range(1,kmax+1)])
fmap = jax.vmap(f,in_axes=(None,-1))
x = 3.
kmaxes = jax.numpy.array([1,2,3])
print(fmap(x,kmaxes))
fmap_sum = lambda k,kmaxes:jax.numpy.sum(fmap(k,kmaxes))
print(fmap_sum(x,kmaxes))
print(jax.grad(fmap_sum)(x,kmaxes))
这将在range(1,kmax+1)
处引发TracerIntegerConversionError。我想让它做的事情是这样的:
import jax
def f(x,kmax):
return sum ([x**k for k in range(1,kmax+1)])
def fmap(x,kmaxes):
return [f(x,kmax) for kmax in kmaxes]
x = 3.
kmaxes = jax.numpy.array([1,2,3])
print(fmap(x,kmaxes))
def fmap_sum(x,kmaxes):
return sum(fmap(x,kmaxes))
print(fmap_sum(x,kmaxes))
print(jax.grad(fmap_sum)(x,kmaxes))
这给出了正确的结果(但是失去了VMAP的并行化和加速)。
1条答案
按热度按时间dauxcl2d1#
首先,要使您的函数与
vmap
兼容,您需要将Python控制流替换为jax.lax
控制流操作。在这种情况下,lax.fori_loop
似乎适用:但是由于循环的大小是动态的,这与反向模式autodiff不兼容:
要解决这个问题,可以修改函数,使其使用静态循环大小。这里有一种方法可以做到这一点: