我试图实现快速例程来计算能量数组,并找到最小的计算值及其索引。下面是我的代码,它工作得很好:
@jit
def findMinEnergy(x):
def calcEnergy(a):
return a*a # very simplified body, it is actually 15 lines of code
energies = vmap(calcEnergy, in_axes=(0))(x)
idx = energies.argmin(axis=0)
minenrgy = energies[idx]
return idx, minenrgy
我想知道是否有可能不使用(单独的)argmin调用,而是从vmap返回min计算的能量值和它的索引(类似于其他聚合函数的工作,例如。我希望它能更有效率。
2条答案
按热度按时间bgibtngc1#
如果您使用JIT编译当前的方法,您应该会发现它与执行更复杂的操作一样高效。
看看
argmin
的实现,你会看到它在只返回索引之前计算了值和索引:https://github.com/google/jax/blob/jax-v0.4.18/jax/_src/lax/lax.py#L3892-L3914如果你愿意,你可以遵循这个实现,并使用
lax.reduce
定义一个函数,在一次传递中返回这两个值:测试这个,我们看到它与不太复杂的方法的输出相匹配:
如果你比较这两者的运行时,你会看到相当的运行时:
这里的
jax.jit
装饰器意味着编译器以不太复杂的方法优化了操作序列,结果是您无法从更巧妙地表达事物中获得太多优势。鉴于此,我认为最好的选择是坚持使用原始代码,而不是试图优化XLA编译器。brvekthn2#
假设高效意味着不必在内存中保留一个大数组(
energies
),那么只需将idx
和minenergy
的各个值堆叠到calcEnergy
中的单个数组中,并将(2,)数组返回给vmap
,而不是(N,)数组。这并不漂亮,因为你(大概)必须将两个值转换为相同的dtype
,但它应该可以正常工作。