NumPy, Python, Numba and complex global arrays

enxuqcxy  posted on 2023-05-22 in Python

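I want to optimize my code with Numba.
My code is organized around a gl.py file that holds some arrays, which are used by main.py and by the functions called from inside main() in main.py.
auxiliary.py looks like this: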

import numpy as np
from numba import jit, types

from cmath import sqrt, exp, sin

N_timesteps_imag = 100

N_epsilon_divs = 60
N_z_divs = 2000
K = N_z_divs # equal to the variable N_z_divs

delta_epsilon = 0.1
delta_z = 0.1

lambd = 1.5

z_max = (N_z_divs/2) * delta_z
epsilon_range = np.linspace(0.0, N_epsilon_divs*delta_epsilon, N_epsilon_divs+1)
z_range = np.linspace(-z_max, z_max, N_z_divs+1)

psi_ground = np.zeros((N_z_divs+1, N_epsilon_divs+1, N_timesteps_imag+1), dtype=types.complex128)

@jit(nopython=True)
def pop_psiground_t0():
    for c1 in range(1, psi_ground.shape[0]-1):
        for c2 in range(1, psi_ground.shape[1]-1):
            zed = (c1 - N_z_divs/2) * delta_z 
            epsi = c2 * delta_epsilon
            psi_ground[c1, c2, 0] = sqrt(3) * epsi * exp(-sqrt(epsi**(2*lambd) + zed**2))

pop_psiground_t0()

main.py looks like this (MWE):

import numpy as np
import auxiliary

def main():
    print(auxiliary.psi_ground[1000, 40, 0]) # shall NOT be 0 + 0j !!!

if __name__ == '__main__':
    main()

No matter which dtype keyword argument I pass in the declaration of psi_ground in auxiliary.py, be it numba.types.complex128, np.complex128 or np.clongdouble, it does not work. In particular, with np.complex128 I get the following error when running python3 main.py:

No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(readonly array(complex128, 3d, C), Tuple(int64, int64, Literal[int](0)), complex128)
 
There are 16 candidate implementations:
  - Of which 14 did not match due to:
  Overload of function 'setitem': File: <numerous>: Line N/A.
    With argument(s): '(readonly array(complex128, 3d, C), UniTuple(int64 x 3), complex128)':
   No match.
  - Of which 2 did not match due to:
  Overload in function 'SetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 171.
    With argument(s): '(readonly array(complex128, 3d, C), UniTuple(int64 x 3), complex128)':
   Rejected as the implementation raised a specific error:
     TypeError: Cannot modify value of type readonly array(complex128, 3d, C)
  raised from /home/velenos14/.local/lib/python3.8/site-packages/numba/core/typing/arraydecl.py:177

During: typing of setitem at /mnt/c/Users/iusti/Desktop/test_python/auxiliary.py (45)

File "auxiliary.py", line 45:
def pop_psiground_t0():
    <source elided>
            epsi = c2 * delta_epsilon
            psi_ground[c1, c2, 0] = sqrt(3) * epsi * exp(-sqrt(epsi**(2*lambd) + zed**2))

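How should I proceed? I tried to follow what is written here: numba TypingError with complex numpy array and native data types
And yes, I need the psi_ground array to be of complex type with high precision, even though it is initially filled with real values. It will be refilled with complex numbers later, in main(). Thank you!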

8dtrkrch  #1

psi_ground = np.zeros((N_z_divs+1, N_epsilon_divs+1, N_timesteps_imag+1), dtype=types.complex128)

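must be defined inside the Numba function. The error states explicitly that Numba cannot modify the values of the psi_ground array: a module-level array referenced from a jitted function is treated as a compile-time constant, so Numba types it as a readonly array.
Here is the modified code: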

import numpy as np
from numba import jit, types

from cmath import sqrt, exp, sin

N_timesteps_imag = 100

N_epsilon_divs = 60
N_z_divs = 2000
K = N_z_divs # equal to the variable N_z_divs

delta_epsilon = 0.1
delta_z = 0.1

lambd = 1.5

z_max = (N_z_divs/2) * delta_z
epsilon_range = np.linspace(0.0, N_epsilon_divs*delta_epsilon, N_epsilon_divs+1)
z_range = np.linspace(-z_max, z_max, N_z_divs+1)

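@jit(nopython=True)
def pop_psiground_t0():
    # Allocate inside the jitted function so the array is writable here.
    psi_ground = np.zeros((N_z_divs+1, N_epsilon_divs+1, N_timesteps_imag+1), dtype=np.complex128)
    for c1 in range(1, psi_ground.shape[0]-1):
        for c2 in range(1, psi_ground.shape[1]-1):
            zed = (c1 - N_z_divs/2) * delta_z
            epsi = c2 * delta_epsilon
            psi_ground[c1, c2, 0] = sqrt(3) * epsi * exp(-sqrt(epsi**(2*lambd) + zed**2))
    # Return the array; otherwise it is discarded when the function exits.
    return psi_ground

# Bind the result at module level so main.py can read auxiliary.psi_ground.
psi_ground = pop_psiground_t0()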
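
As a quick check of the allocate-inside-the-function pattern, here is a scaled-down sketch. The grid sizes and the names `N_Z`, `N_EPS`, `N_T` and `make_psi_ground` are illustrative assumptions, not taken from the post. The function allocates the array inside the compiled code and returns it, so the caller receives an ordinary complex128 array. The `try`/`except` fallback is also an assumption of this sketch, there only so it still runs as plain Python where Numba is not installed.

```python
import numpy as np
from cmath import sqrt, exp

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: run as plain Python if Numba is missing.
    def njit(func):
        return func

# Illustrative, scaled-down grid (the post uses 2000 x 60 x 100 divisions).
N_Z, N_EPS, N_T = 20, 6, 3
DELTA_Z, DELTA_EPS, LAMBD = 0.1, 0.1, 1.5


@njit
def make_psi_ground():
    # Allocated inside the jitted function, so it is writable there.
    psi = np.zeros((N_Z + 1, N_EPS + 1, N_T + 1), dtype=np.complex128)
    for c1 in range(1, psi.shape[0] - 1):
        for c2 in range(1, psi.shape[1] - 1):
            zed = (c1 - N_Z / 2) * DELTA_Z
            epsi = c2 * DELTA_EPS
            psi[c1, c2, 0] = sqrt(3) * epsi * exp(-sqrt(epsi ** (2 * LAMBD) + zed ** 2))
    return psi


# The returned array is what other modules would import.
psi_ground = make_psi_ground()
print(psi_ground.dtype, psi_ground[N_Z // 2, 4, 0])
```

The boundary rows and columns are left at zero, as in the original loops, while interior entries of the first time slice are nonzero.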

wfauudbj  #2

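Simple. Just pass the array in (by reference) as an argument and specify its type correctly. Change the following lines (do not miss the comma and the parentheses):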

from numba import jit, types, complex128

psi_ground = np.zeros((N_z_divs+1, N_epsilon_divs+1, N_timesteps_imag+1), dtype=np.complex128)

@jit( (complex128[:,:,:],), nopython=True  )
def pop_psiground_t0(psi_ground):
    ...  # loop body unchanged from the question

pop_psiground_t0(psi_ground)
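
Put together, the argument-passing approach might look like the following scaled-down sketch. The name `fill_t0` and the grid sizes are illustrative assumptions. The explicit signature from the answer is left out here so that Numba infers the type on the first call, which is a choice made for this sketch and not part of the answer. The `try`/`except` fallback is likewise only there so the sketch runs as plain Python where Numba is not installed.

```python
import numpy as np
from cmath import sqrt, exp

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: run as plain Python if Numba is missing.
    def njit(func):
        return func

# Illustrative, scaled-down grid (the post uses 2000 x 60 x 100 divisions).
N_Z, N_EPS, N_T = 20, 6, 3
DELTA_Z, DELTA_EPS, LAMBD = 0.1, 0.1, 1.5

# Module-level array, created by NumPy with a NumPy dtype.
psi_ground = np.zeros((N_Z + 1, N_EPS + 1, N_T + 1), dtype=np.complex128)


@njit
def fill_t0(psi):
    # psi is an argument rather than a global, so it is writable in here.
    for c1 in range(1, psi.shape[0] - 1):
        for c2 in range(1, psi.shape[1] - 1):
            zed = (c1 - N_Z / 2) * DELTA_Z
            epsi = c2 * DELTA_EPS
            psi[c1, c2, 0] = sqrt(3) * epsi * exp(-sqrt(epsi ** (2 * LAMBD) + zed ** 2))


fill_t0(psi_ground)                # fills the module-level array in place
psi_ground[1, 1, 1] = 2.0 + 3.0j   # later complex writes from plain Python
print(psi_ground[N_Z // 2, 4, 0], psi_ground[1, 1, 1])
```

Because the array is modified in place, main.py can keep reading `auxiliary.psi_ground` and overwrite entries with complex values later.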
