numpy np.sqrt与整数和其中条件返回错误的结果

wz8daaqr  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(95)

我从numpy sqrt方法中得到奇怪的结果,当将其应用于具有where条件的整数数组时。见下文。
对于整数:

a = np.array([1, 4, 9])

np.sqrt(a, where=(a>5))
Out[3]: array([0. , 0.5, 3. ])

字符串
关于floats:

a = np.array([1., 4., 9.])

np.sqrt(a, where=(a>5))
Out[25]: array([1., 4., 3.])


这是一个bug还是对这个方法的误解?

f87krz0w

f87krz0w1#

我想可能有bug不一致 *,当重复命令几次时,我没有得到一致的结果。

  • 这实际上是由于可能使用了numpy.empty来初始化输出。此函数在不设置默认值的情况下重用内存,因此在单元格中将具有不可预测的值,而这些值不会被手动覆盖。

下面是一个例子:

import numpy as np
print(np.__version__)

a = np.array([1, 4, 9])                   # integers
print(np.sqrt(a, where=(a>5)).round(3))

a = np.array([1., 4., 9.])                # floats
print(np.sqrt(a, where=(a>5)).round(3))

a = np.array([1, 4, 9])                   # integers again
print(np.sqrt(a, where=(a>5)).round(3))

字符串
输出量:

1.24.2
[0. 0. 3.]
[0. 0. 3.]
[1. 4. 3.]  # this is now different


从@hpaulj的评论来看,似乎需要提供out。实际上,这防止了与上述示例不一致的行为。
out应该使用可预测的函数初始化,例如numpy.zeros

import numpy as np
print(np.__version__)

out = np.zeros(3)
a = np.array([1, 4, 9])      # integers
print(np.sqrt(a, where=(a>5), out=out))
print(out)

out = np.zeros(3)
a = np.array([1., 4., 9.])   # floats
print(np.sqrt(a, where=(a>5), out=out))
print(out)

a = np.array([1, 4, 9])      # integers again
print(np.sqrt(a, where=(a>5), out=out))
print(out)


输出量:

1.24.2
[0. 0. 3.]
[0. 0. 3.]
[0. 0. 3.]
[0. 0. 3.]
[0. 0. 3.]
[0. 0. 3.]


尽管如此,这似乎是一种前后矛盾的行为。
doc指定它不提供out,则应分配一个新数组:
输出ndarray、None或ndarray和None的元组,可选
存储结果的位置。如果提供,则它必须具有输入广播到的形状。如果未提供或为“无”,则返回新分配的数组。Tuple(只能作为关键字参数)的长度必须等于输出的数量。
其中array_like,可选
此条件通过输入广播。在条件为True的位置,out数组将被设置为ufunc结果。在其他地方,out数组将保留其原始值。请注意,如果通过默认out=None创建了未初始化的out数组,则其中条件为False的位置将保持未初始化。
在您的情况下,您更希望使用numpy.where

out = np.where((a>5), np.sqrt(a), a)

# or, to avoid potential errors if you have forbidden values
out = np.where(a>5, np.sqrt(a, where=a>5), a)

# or
out = np.sqrt(a, where=a>5, out=a.astype(float))


输出:array([1., 4., 3.])

oxf4rvwz

oxf4rvwz2#

ufuncs似乎有问题,因为默认的out参数将创建空的未初始化数组(这可能是导致此错误的原因)-这里有更详细的描述:"where" clause in numpy-1.13 ufuncs的数据。
来自numpy docs(https://numpy.org/doc/stable/reference/ufuncs.html):
如果'out'为None(默认值),则创建一个未初始化的返回数组。然后,输出数组在广播'where'为True的位置填充ufunc的结果。如果'where'是标量True(默认值),则这对应于填充整个输出。请注意,未显式填充的输出将保留其未初始化的值。
...
请注意,如果创建了一个未初始化的返回数组,False值将使这些值未初始化。
解决方法是添加一个显式的out参数。
举例来说:

>>> a = np.array([1, 4, 9])
>>> np.sqrt(a, where=(a>5))
array([5.e-324, 2.e-323, 3.e+000])

>>> o = np.zeros(3)
>>> np.sqrt(a, where=(a>5), out=o)
array([0., 0., 3.])

>>> a = np.array([1., 4., 9.])
>>> np.sqrt(a, where=(a>5))
array([5.e-324, 2.e-323, 3.e+000])

>>> o = np.zeros(3)
>>> np.sqrt(a, where=(a>5), out=o)
array([0., 0., 3.])

字符串
要获得您(可能)最初想要的东西,numpy.where可能是一个不错的选择:

>>> a = np.array([1, 4, 9])
>>> np.where(a > 5, np.sqrt(a), a)
array([1., 4., 3.])

>>> a = np.array([1., 4., 9.])
>>> np.where(a > 5, np.sqrt(a), a)
array([1., 4., 3.])

相关问题