numpy 在循环中计算范数会减慢Dask的计算速度

nwlls2ji  于 2023-05-07  发布在  其他
关注(0)|答案(1)|浏览(144)

我试图使用Dask实现一个共轭梯度算法(出于教学目的),当我意识到性能比一个简单的numpy实现差得多。经过几次实验,我已经能够将问题简化为以下片段:

import numpy as np
import dask.array as da
from time import time

def test_operator(f, test_vector, library=np):
    for n in (10, 20, 30):
        v = test_vector()

        start_time = time()
        for i in range(n):
            v = f(v)
            k = library.linalg.norm(v)
    
            try:
                k = k.compute()
            except AttributeError:
                pass
            print(k)
        end_time = time()

        print('Time for {} iterations: {}'.format(n, end_time - start_time))

print('NUMPY!')
test_operator(
    lambda x: x + x,
    lambda: np.random.rand(4_000, 4_000)
)

print('DASK!')
test_operator(
    lambda x: x + x,
    lambda: da.from_array(np.random.rand(4_000, 4_000), chunks=(2_000, 2_000)),
    da
)

在代码中,我简单地将一个向量乘以2(这就是f所做的)并打印其范数。当使用dask运行时,每次迭代的速度都会慢一些。如果我不计算kv的范数,这个问题就不会发生。
不幸的是,在我的例子中,k是我用来停止共轭梯度算法的残差的范数。如何避免这个问题?为什么会这样?
谢谢大家!

1u4esq0p

1u4esq0p1#

我认为这段代码在dask中没有使用惰性求值,特别是加法运算。在没有优化的情况下,添加lambda x: x+x会使执行图复杂化,深度随着计数器的增加而增加,因此会产生开销。更准确地说,对于计数器值i,我们在计算范数时处理O(i)的图,因此总运行时间为O(n**2)。当然,优化是可能的和期望的,但我在这里停止,因为共享的示例是合成的。下面我将演示图形随计数器线性增长。

要直观地看到二次复杂度,请考虑下面的代码片段的清理版本

import numpy as np
import dask.array as da
from time import time
import matplotlib.pyplot as plt

ns = (10, 20, 40, 50, 60)

def test_operator(f, v, norm):
  out = []
  for n in ns:
    start_time = time()
    for i in range(n):
      v = f(v)
      norm(v)
    end_time = time()
    out.append(end_time - start_time)
  return out

out = test_operator(
    lambda x:x+x,
    np.random.rand(4_000, 4_000),
    norm = np.linalg.norm
)
plt.scatter(ns,out,label='numpy')

out = test_operator(
    lambda x:x+x,
    da.from_array(np.random.rand(4_000, 4_000), chunks=(2_000, 2_000)),
    norm = lambda v: da.linalg.norm(v).compute()
)

plt.scatter(ns,out,label='dask')

plt.legend()
plt.show()

相关问题