带if语句的Numpy数组函数

x8goxv8g  于 2023-01-26  发布在  其他
关注(0)|答案(6)|浏览(125)

我使用MatplotlibNumpy来生成一些图。我希望定义一个函数,给定一个数组返回另一个数组,该数组的值按元素方式计算,例如:

def func(x):
     return x*10

x = numpy.arrange(-1,1,0.01)
y = func(x)

这很好,但是现在我希望在func中有一个if语句,例如:

def func(x):
     if x<0:
          return 0
     else:
          return x*10

x = numpy.arrange(-1,1,0.01)
y = func(x)

不幸的是,这会引发以下错误

Traceback (most recent call last):
  File "D:\Scripts\test.py", line 17, in <module>
    y = func(x)
  File "D:\Scripts\test.py", line 11, in func
    if x<0:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我看了all()any()的文档,它们不符合我的要求。那么,有没有一种好方法可以让函数像第一个例子那样按元素方式处理数组呢?

6tqwzwtp

6tqwzwtp1#

我知道现在回答这个问题已经太晚了,但是我对学习NumPy很感兴趣。您可以使用numpy.where自己对函数进行矢量化。

def func(x):
    import numpy as np
    x = np.where(x<0, 0., x*10)
    return x
    • 示例**

使用标量作为数据输入:

x = 10
y = func(10)
y = array(100.0)

使用数组作为数据输入:

x = np.arange(-1,1,0.1)
y = func(x)
y = array([ -1.00000000e+00,  -9.00000000e-01,  -8.00000000e-01,
    -7.00000000e-01,  -6.00000000e-01,  -5.00000000e-01,
    -4.00000000e-01,  -3.00000000e-01,  -2.00000000e-01,
    -1.00000000e-01,  -2.22044605e-16,   1.00000000e-01,
     2.00000000e-01,   3.00000000e-01,   4.00000000e-01,
     5.00000000e-01,   6.00000000e-01,   7.00000000e-01,
     8.00000000e-01,   9.00000000e-01])
    • 注意事项**:

1)如果x是掩码数组,则需要使用np.ma.where,因为这适用于掩码数组。

mqxuamgl

mqxuamgl2#

在将func应用于数组x之前,使用numpy.vectorize将其 Package :

from numpy import vectorize
vfunc = vectorize(func)
y = vfunc(x)
u4vypkhs

u4vypkhs3#

这应该可以满足您的需要:

def func(x):
    small_indices = x < 10
    x[small_indices] = 0
    x[invert(small_indices)] *= 10
    return x

invert是一个Numpy函数。注意,这会修改参数。要防止这种情况,您必须修改并返回xcopy

d7v8vwbk

d7v8vwbk4#

  • (我知道这是个老问题,但...)*

还有一个选项在这里没有提到--使用np.choose

np.choose(
    # the boolean condition
    x < 0,
    [
        # index 0: value if condition is False
        10 * x,
        # index 1: value if condition is True
        0
    ]
)

虽然可读性不强,但这只是一个表达式(不是一系列语句),并没有降低numpy固有的速度(如np.vectorize)。

shyt4zoc

shyt4zoc5#

x = numpy.arrange(-1,1,0.01)
mask = x>=0
y = numpy.zeros(len(x))
y[mask] = x[mask]*10

mask是一个布尔数组,它等于True是数组索引,匹配条件,False在其他地方。最后一行用乘以10的值替换原始数组中的所有值。
编辑以反映比约恩的相关评论

xkrw2x1b

xkrw2x1b6#

不确定为什么需要函数

x = np.arange(-1, 1, 0.01)
y = x * np.where(x < 0, 0, 10)

相关问题