我正在尝试实现numpy的ufunc来处理一个类,使用numpy v1.13中引入的__array_ufunc__方法。
为了简化,类可能看起来像这样:
class toto():
def __init__(self, value, name):
self.value = value
self.name = name
def __add__(self, other):
"""add values and concatenate names"""
return toto(self.value + other.value, self.name + other.name)
def __sub__(self, other):
"""sub values and concatenate names"""
return toto(self.value - other.value, self.name + other.name)
tata = toto(5, "first")
titi = toto(1, "second")
现在,如果我尝试在这两者之间应用np.add,我会得到预期的结果,因为np.add依赖于add。但是如果我调用say np.exp,我会得到预期的错误:
>>> np.exp(tata)
AttributeError: 'toto' object has no attribute 'exp'
现在我想做的是“覆盖”所有numpy ufuncs,以便在这个类中顺利工作,而不必重新定义类中的每个方法(exp(self),log(self),...)。
我计划使用numpy ufunc的[array_ufunc]1来实现这一点,但我并不真正理解文档,因为它没有提供一个简单的实现示例。
如果有人对这种看起来很有前途的新功能有任何经验,你能提供一个简单的例子吗?
2条答案
按热度按时间0qx6xfy61#
如果我用
__array_ufunc__
方法(和__repr__
)扩展你的类:尝试一些
ufunc
调用:这显示了类接收到的信息,显然你可以做你想做的事情,它可以返回
NotImplemented
,我想在你的例子中,它可以把第一个参数应用到你的self.value
,或者做一些自定义的计算。例如,如果我添加
我得到:
但是,如果我把对象放在数组中,我仍然会得到方法错误
显然,
ufunc
在一个对象dtype数组上迭代数组的元素,期望使用一个'相关'方法。对于np.add
(+),它会查找__add__
方法。对于np.exp
,它会查找exp
方法。这个__array_ufunc__
不会被调用。所以它看起来更像是
ndarray
的子类,或者类似的东西。我想,你正在尝试实现一个可以作为对象dtype数组元素的类。hxzsmxv22#
我认为您缺少
__array_function__
协议,https://numpy.org/neps/nep-0018-array-function-protocol.html__array_ufunc__
将只与某些numpy ufunc一起工作,但不是所有的。当不可用时,numpy将使用__array_function__
协议进行调度,https://numpy.org/devdocs/release/1.17.0-notes.html#numpy-functions-now-always-support-overrides-with-array-function下面是一个简单的例子: