当函数包含条件if语句时,如何通过函数传递numpy数组?

yfwxisqw  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(98)

我有以下代码:

import numpy as np
from astropy.cosmology import FlatLambdaCDM
import matplotlib.pyplot as plt

cosmopar = FlatLambdaCDM(H0 = 67.8,Om0 = 0.3)

td_min = 0.1
d = -1

def t_L(z_arb):
    return cosmopar.lookback_time(z_arb).value

def t_d(z_f,z_m):
    return t_L(z_f)-t_L(z_m)

def P_t(z_f,z_m):
    if (td_min<t_d):
        return t_d**d
    else:
        return 0

字符串
现在,如果我定义了一个numpy数组zf_trial1 = np.linspace(0,30,100),并尝试使用命令P_t(zf_trial1,3)将其传递给函数,则函数将返回以下错误语句:
“包含多个元素的数组的真值不明确。使用a.any()或a.all()”
现在我明白了为什么会出现这个错误了--在if语句中使用td_min进行比较时,传递一个包含多个元素的'array'会导致数组中的某些元素满足if语句的条件,而某些元素不满足该条件;但是,我不知道如何解决这个问题。总的来说,我所要做的就是将NumPy数组zf_trial1的每个元素传递到P_t(z_f,z_m)
我尝试了np.vectorize()函数,但效果似乎不太好,结果似乎很混乱,因为当我绘制函数结果时,我收到的图形与我手动输入值到P_t函数中然后绘制它时收到的图形不同。我所尝试的如下:

Pt_vector = np.vectorize(P_t)
Pt_res = Pt_vector(zf_trial1,3)

plt.scatter(zf_trial1,Pt_res)

ldxq2e6h

ldxq2e6h1#

我有点困惑,因为P_tz_fz_m作为参数,但它们根本没有在函数中使用,因此我真的不知道应该如何使用它,因此我将尝试以一般的方式回答这个问题。
你可以在函数中使用np.where来创建一个过滤器,如下所示

def ex(arr,max_value):
      """
      arr: np.array
      max_value: int

      All values in `arr` below `max_value`
      are raised to a power of `d`.
      Values below are set as `0`
      """
      return np.where(arr < max_value, arr**d, 0)

d = 2
ex(np.arange(0,10),5) # array([ 0,  1,  4,  9, 16,  0,  0,  0,  0,  0])

字符串
或者你可以使用list-comprehension:

a = np.array([P_t(v,z_m) for v in zf_trial1])

相关问题