针对if语句中的特定值测试numpy数组

qlvxas9a  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(72)

我有一个numpy数组x = np.linspace(0, 10, 11)。我想返回数组中介于 ab 之间的值,然后返回其他值之间的值。
我写道:

import numpy as np
x = np.linspace(0, 10, 11)
if 2 <= x < 7:
    return x**2
else:
    return x**3

字符串
我收到一条错误消息:ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
然而,这两个函数针对整个数组进行测试并返回所有元素。我也尝试了np.where()np.logical_and()函数,得到了同样的错误。

x7rlezfr

x7rlezfr1#

np.where((x>=2)&(x<7), x**2, x**3)

字符串
就是你要找的东西。
np.where(最多)采用3个参数。一个条件(因此是一个布尔值数组),以及两个根据条件选择的值。
请注意,您的程式码中有3个问题。
1.如果x是一个numpy数组,你就不能用if x<7x<7是一个布尔值数组。和和数组的布尔值既不为真也不为假。这就是你得到的错误信息的含义。
1.语法2<=x<7不可能与数组一起使用。2<=x是一个布尔值数组。x<7是另一个布尔数组。您可以使用运算符&对两个布尔值数组进行and运算。但2<=x<7不是布尔值数组。因为在现实中同样的原因:2<=x<7(2<=x) and (x<7)的Python快捷方式。如果x是一个标量,那么这是有意义的,因此2<=xx<7都是布尔值(或者任何具有“真值”的值)。但如果2<=xx<7是数组,就不是了(因此,对于这类事情,还有另一个运算符&)
1.即使它在工作,即使if 2<=x<7:...有意义,你会从那里做什么。这是一个单一的回报。这不是量子力学:则结果是return x**2或者是return x**3。不可能两者都是。因此,无论如何,逻辑都不起作用。不可能这么简单。您要返回的既不是x**2,也不是x**3,而是由两者组成的新数组。
嗯,我说“不可能那么简单”,但正如你所看到的,正确的答案也很简单。但它依赖于numpy函数,而不是纯粹的python if/else

另一种方法

这不是唯一的 numpy 方式。
例如,您还可以

ret=np.empty_like(x)
cond=(x>=2)&(x<7)
ret[cond]=x[cond]**2
ret[~cont]=x[~cond]**2
return ret


该方法的一个优点是,它不会在任何地方计算x**2x**3,即使在不需要的地方也是如此(而np.where方法会计算所有数字的**2**3,并针对每种情况选择2个中的一个)。
缺点是它需要几个索引操作。在这种情况下,由于索引操作的扩展性并不比**2**3小,因此速度并不快。在这个例子中,不管x的大小如何,np.where方法在我的机器上的速度要快30%,即使它做了无用的计算。
但是在一个更复杂的情况下,比如说

def whereBased(x):
    return np.where((x>=2)&(x<7), np.sin(2*np.sqrt(np.exp(x))*np.pi), np.cos(2*np.sqrt(np.exp(x))*np.pi))


因此,如果运算的开销比**2稍大,则值得尝试避免在所有x上同时计算sin和cos,并放弃一半的计算。

def indexBased(x):
    ret=np.empty_like(x)
    cond=(x>=2)&(x<7)
    ret[cond] = np.sin(2*np.sqrt(np.exp(x[cond]))*np.pi)
    ret[~cond] = np.cos(2*np.sqrt(np.exp(x[~cond]))*np.pi)
    return ret


这一次,情况正好相反。基于索引的解决方案的速度提高了30%。
这仍然是一个非常简单的例子(每个数字只有2次乘法,一个sin/cos,一个exp和一个sqrt)。
您的实际计算越复杂(我认为您的真实的用例不是x**2x**3),这种只执行所需计算的解决方案就越有价值。

ijnw1ujt

ijnw1ujt2#

比较2 <= x < 7无效。这里比较的是NumPy数组,而不是离散值。
您可以使用NumPywhere方法:

x = np.linspace(0, 10, 11)

result = np.where((2 <= x) & (x < 7), x**2, x**3)

字符串
其输出:

array([   0.,    1.,    4.,    9.,   16.,   25.,   36.,  343.,  512.,
        729., 1000.])

相关问题