numpy 聚合计算vmap

tkqqtvp1  于 2023-10-19  发布在  其他
关注(0)|答案(2)|浏览(127)

我试图实现快速例程来计算能量数组,并找到最小的计算值及其索引。下面是我的代码,它工作得很好:

@jit
def findMinEnergy(x):
  def calcEnergy(a):
    return a*a  # very simplified body, it is actually 15 lines of code
  energies = vmap(calcEnergy, in_axes=(0))(x)
  idx = energies.argmin(axis=0)
  minenrgy = energies[idx]
  return idx, minenrgy

我想知道是否有可能不使用(单独的)argmin调用,而是从vmap返回min计算的能量值和它的索引(类似于其他聚合函数的工作,例如。我希望它能更有效率。

bgibtngc

bgibtngc1#

如果您使用JIT编译当前的方法,您应该会发现它与执行更复杂的操作一样高效。
看看argmin的实现,你会看到它在只返回索引之前计算了值和索引:https://github.com/google/jax/blob/jax-v0.4.18/jax/_src/lax/lax.py#L3892-L3914
如果你愿意,你可以遵循这个实现,并使用lax.reduce定义一个函数,在一次传递中返回这两个值:

import jax
import jax.numpy as jnp

@jax.jit
def min_and_argmin_onepass(x):
  # This only works for 1D float arrays, but you could generalize it.
  assert x.ndim == 1
  assert jnp.issubdtype(x.dtype, jnp.floating)
  def reducer(op_val_index, acc_val_index):
    op_val, op_index = op_val_index
    acc_val, acc_index = acc_val_index
    pick_op_val = (op_val < acc_val) | jnp.isnan(op_val)
    pick_op_index = pick_op_val | ((op_val == acc_val) & (op_index < acc_index))
    return (jnp.where(pick_op_val, op_val, acc_val),
            jnp.where(pick_op_index, op_index, acc_index))
  indices = jnp.arange(len(x))
  return jax.lax.reduce((x, indices), (jnp.inf, 0), reducer, (0,))

测试这个,我们看到它与不太复杂的方法的输出相匹配:

@jax.jit
def min_and_argmin(x):
  i = jnp.argmin(x)
  return x[i], i

x = jax.random.uniform(jax.random.key(0), (1000000,))
print(min_and_argmin_onepass(x))
# (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))
print(min_and_argmin(x))
# (Array(9.536743e-07, dtype=float32), Array(24430, dtype=int32))

如果你比较这两者的运行时,你会看到相当的运行时:

%timeit jax.block_until_ready(min_and_argmin_onepass(x))
# 2.17 ms ± 68.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit jax.block_until_ready(min_and_argmin(x))
# 2.07 ms ± 66.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这里的jax.jit装饰器意味着编译器以不太复杂的方法优化了操作序列,结果是您无法从更巧妙地表达事物中获得太多优势。鉴于此,我认为最好的选择是坚持使用原始代码,而不是试图优化XLA编译器。

brvekthn

brvekthn2#

假设高效意味着不必在内存中保留一个大数组(energies),那么只需将idxminenergy的各个值堆叠到calcEnergy中的单个数组中,并将(2,)数组返回给vmap,而不是(N,)数组。这并不漂亮,因为你(大概)必须将两个值转换为相同的dtype,但它应该可以正常工作。

相关问题