如何避免NumPy中嵌套的for循环?

ctehm74n  于 2022-11-10  发布在  其他
关注(0)|答案(3)|浏览(130)

我有这个密码。

n_nodes = len(data_x)
X = np.zeros((n_nodes, n_nodes))

for i in range(n_nodes):
  for j in range(n_nodes):
    X[i, j] = data_x[i]**j

我想做同样的任务,根本不使用循环。我如何使用NumPy函数来实现这一点?

nxagd54h

nxagd54h1#

我建议你

data_x[:,None]**np.arange(n_nodes)

一张支票

In [17]: data_x = np.array((3,5,7,4,6))
    ...: n_nodes = len(data_x)
    ...: X = np.zeros((n_nodes, n_nodes))
    ...: 
    ...: for i in range(n_nodes):
    ...:   for j in range(n_nodes):
    ...:     X[i, j] = data_x[i]**j
    ...: print(X)
    ...: print('-----------')
    ...: print(data_x[:,None]**np.arange(n_nodes))
[[1.000e+00 3.000e+00 9.000e+00 2.700e+01 8.100e+01]
 [1.000e+00 5.000e+00 2.500e+01 1.250e+02 6.250e+02]
 [1.000e+00 7.000e+00 4.900e+01 3.430e+02 2.401e+03]
 [1.000e+00 4.000e+00 1.600e+01 6.400e+01 2.560e+02]
 [1.000e+00 6.000e+00 3.600e+01 2.160e+02 1.296e+03]]
-----------
[[   1    3    9   27   81]
 [   1    5   25  125  625]
 [   1    7   49  343 2401]
 [   1    4   16   64  256]
 [   1    6   36  216 1296]]

一些时间安排

In [18]: %timeit data_x[:,None]**np.arange(n_nodes)
2.18 µs ± 7.49 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [19]: %%timeit
    ...: for i in range(n_nodes):
    ...:     for j in range(n_nodes):
    ...:         X[i, j] = data_x[i]**j
10.9 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
wribegjk

wribegjk2#

您可以使用numpy.power.outer一步完成此操作:

np.power.outer(data_x, np.arange(len(data_x)))
oxalkeyp

oxalkeyp3#

如果data_x很大,那么只使用NumPy函数会更快。您可以首先重复输入数组,然后使用np.power和一个给出幂的向量。与已经给出的列表理解版本相比,这应该是完全矢量化的计算。

x = np.arange(10)
X = x[:,np.newaxis].repeat(x.size,axis=1)
X = np.power(X,np.arange(x.size))

如果data_x已经是Numy数组,则可以直接使用它,如果不是,则需要这样做

x = np.array(data_x)

相关问题