在大型numpy数组中计数零而不创建它们

1szpjjfi  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(107)

为了说明我面临的问题,这里有一些示例代码:

a = np.round(np.random.rand(10, 15))
counta = np.count_nonzero(a, axis=-1)
print(counta)

A = np.einsum('im,mj->ijm', a, a.T)
countA = np.count_nonzero(A, axis=-1)
print(countA)

它创建一个2D数组,并沿着最后一个轴对其非零元素进行计数。然后,它创建一个3D数组,其中的非零元素再次沿着最后一个轴计数。
我的问题是我的数组a太大了,我可以执行第一步,但不能执行第二步,因为A数组会占用太多内存。

有没有办法得到countA?也就是说,在不实际创建数组的情况下,沿着给定的轴计算A中的零?

dkqlctbz

dkqlctbz1#

我认为你可以简单地使用矩阵乘法(点积)来得到你的结果,而不需要生成那个庞大的3D数组A

a = np.round(np.random.rand(10, 15)).astype(int)
counta = np.count_nonzero(a, axis=-1)

A = np.einsum('im,mj->ijm', a, a.T)
countA = np.count_nonzero(A, axis=-1)

assert np.all(countA == (a @ a.T))

这也更快:

a = np.round(np.random.rand(1000, 1500)).astype(int)

%timeit np.count_nonzero(np.einsum('im,mj->ijm', a, a.T), axis=-1)
3.94 s ± 38.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit a @ a.T
558 ms ± 6.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

还要注意,第一步与第二步是多余的:

assert np.all(counta == np.diag(countA))

相关问题