在我的应用程序中,我需要取一个函数的n阶混合导数。然而,我发现torch.autograd.grad的计算时间随着n的增加而呈指数级增长。这是预料之中的吗?有什么办法可以绕过它吗?
这是我对函数self.F(from R^n -> R^1)求微分的代码:
def differentiate(self, x):
x.requires_grad_(True)
xi = [x[...,i] for i in range(x.shape[-1])]
dyi = self.F(torch.stack(xi, dim=-1))
for i in range(self.dim):
start_time = time.time()
dyi = torch.autograd.grad(dyi.sum(), xi[i], retain_graph=True, create_graph=True)[0]
grad_time = time.time() - start_time
print(grad_time)
return dyi
这些是上面循环的每次迭代打印的时间:
0.0037012100219726562
0.005133152008056641
0.008165121078491211
0.019922733306884766
0.059255123138427734
0.1910409927368164
0.6340939998626709
2.1612229347229004
11.042078971862793
我假设这是因为计算图的大小正在增加?有什么办法可以解决这个问题吗?我想我也许可以通过使用torch. func. grad采用函数方法(大概可以避免对计算图的需要)来规避这个问题。然而,这实际上 * 增加了 * 相同代码的运行时间!我是不是没听懂torch.func.grad的意思?JAX中的类似实现是否会提供任何性能提升?
1条答案
按热度按时间63lcw9qa1#
也许输出只有6个元素,就像softmax一样,它有6个输入和6个输出,但它的导数是一个形状为6x6的雅可比矩阵,如果你试图得到二阶导数,它将超过6x6x6,它超过了一个三维海森矩阵。
你应该检查它,它的例子,像这样的https://github.com/HIPS/autograd/blob/master/examples/tanh.py
矢量到矢量的梯度将是矩阵或多维矩阵。所以它所用的时间将呈指数级增长。
或者您可以参考此softmax_https://zhuanlan.zhihu.com/p/657177292