numpy 仅包含-1或+1的3D数组的数组乘积(einsum)

yshpjwxd  于 12个月前  发布在  其他
关注(0)|答案(2)|浏览(106)

设X是形状为(M,k,g)的数组,Q是形状为(m,k,g)的数组,其中m、M、k和g可以是“非常大”的。假设X和Y的元素是-1或+1。我感兴趣的是由Z[a,b,i] = Q[a,i] @ X[b,i]定义的形状为(m,M,k)的数组Z。很明显,这可以在Numpy(或Jax,用于GPU开发)中完成,如下所示

Z = einsum("mkg,Mkg->mMk", Q, X)

**问题:**通过使用X和Y的元素是-1或+1的信息,可以更有效地计算Z吗?

1aaf6o9v

1aaf6o9v1#

正如我在评论中提到的,转换为int8似乎可以给予一个速度。如果你能够并且愿意使用torch,我看到在使用他们的einsum时会加速(不能详细说明他们使用的是哪些优化),但是如果转换为int8类型,似乎会慢一些:

5cg8jx4n

5cg8jx4n2#

下面是GPU(T4)上的更多基准测试,以补充公认的答案。Numpy vs JAX vs PyTorch
CPU

GPU

结论(基于这个小实验)是:

  • 只有Numpy知道如何充分利用int8。其他人可能在引擎盖下做一些奇怪的铸造。
  • GPU帮助JAX。
  • PyTorch是GPU上的野兽。

相关问题