如何使用__array_ufunc__覆盖numpy ufunc

q5lcpyga  于 2023-03-30  发布在  其他
关注(0)|答案(2)|浏览(99)

我正在尝试实现numpy的ufunc来处理一个类,使用numpy v1.13中引入的__array_ufunc__方法。
为了简化,类可能看起来像这样:

class toto():
    def __init__(self, value, name):
        self.value = value
        self.name = name
    def __add__(self, other):
    """add values and concatenate names"""
        return toto(self.value + other.value, self.name + other.name)
    def __sub__(self, other):
    """sub values and concatenate names"""
        return toto(self.value - other.value, self.name + other.name)  

tata = toto(5, "first")  
titi = toto(1, "second")

现在,如果我尝试在这两者之间应用np.add,我会得到预期的结果,因为np.add依赖于add。但是如果我调用say np.exp,我会得到预期的错误:

>>> np.exp(tata)
AttributeError: 'toto' object has no attribute 'exp'

现在我想做的是“覆盖”所有numpy ufuncs,以便在这个类中顺利工作,而不必重新定义类中的每个方法(exp(self),log(self),...)。
我计划使用numpy ufunc的[array_ufunc]1来实现这一点,但我并不真正理解文档,因为它没有提供一个简单的实现示例。
如果有人对这种看起来很有前途的新功能有任何经验,你能提供一个简单的例子吗?

0qx6xfy6

0qx6xfy61#

如果我用__array_ufunc__方法(和__repr__)扩展你的类:

class toto():
    def __init__(self, value, name):
        self.value = value
        self.name = name
    def __add__(self, other):
        """add values and concatenate names"""
        return toto(self.value + other.value, self.name + other.name)
    def __sub__(self, other):
        """sub values and concatenate names"""
        return toto(self.value - other.value, self.name + other.name)

    def __repr__(self):
        return f"toto: {self.value}, {self.name}"
    def __array_ufunc__(self, *args, **kwargs):
        print(args)
        print(kwargs)

尝试一些ufunc调用:

In [458]: np.exp(tata)                                                          
(<ufunc 'exp'>, '__call__', toto: 5, first)
{}
In [459]: np.exp.reduce(tata)                                                   
(<ufunc 'exp'>, 'reduce', toto: 5, first)
{}
In [460]: np.multiply.reduce(tata)                                              
(<ufunc 'multiply'>, 'reduce', toto: 5, first)
{}
In [461]: np.exp.reduce(tata,axes=(1,2))                                        
(<ufunc 'exp'>, 'reduce', toto: 5, first)
{'axes': (1, 2)}
In [463]: np.exp.reduce(tata,axes=(1,2),out=np.arange(3))                       
(<ufunc 'exp'>, 'reduce', toto: 5, first)
{'axes': (1, 2), 'out': (array([0, 1, 2]),)}

这显示了类接收到的信息,显然你可以做你想做的事情,它可以返回NotImplemented,我想在你的例子中,它可以把第一个参数应用到你的self.value,或者做一些自定义的计算。
例如,如果我添加

val = args[0].__call__(self.value) 
      return toto(val, self.name)

我得到:

In [468]: np.exp(tata)                                                          
(<ufunc 'exp'>, '__call__', toto: 5, first)
{}
Out[468]: toto: 148.4131591025766, first
In [469]: np.sin(tata)                                                          
(<ufunc 'sin'>, '__call__', toto: 5, first)
{}
Out[469]: toto: -0.9589242746631385, first

但是,如果我把对象放在数组中,我仍然会得到方法错误

In [492]: np.exp(np.array(tata))                                                
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-492-4dc37eb0ffe3> in <module>
----> 1 np.exp(np.array(tata))

AttributeError: 'toto' object has no attribute 'exp'

显然,ufunc在一个对象dtype数组上迭代数组的元素,期望使用一个'相关'方法。对于np.add(+),它会查找__add__方法。对于np.exp,它会查找exp方法。这个__array_ufunc__不会被调用。
所以它看起来更像是ndarray的子类,或者类似的东西。我想,你正在尝试实现一个可以作为对象dtype数组元素的类。

hxzsmxv2

hxzsmxv22#

我认为您缺少__array_function__协议,https://numpy.org/neps/nep-0018-array-function-protocol.html
__array_ufunc__将只与某些numpy ufunc一起工作,但不是所有的。当不可用时,numpy将使用__array_function__协议进行调度,https://numpy.org/devdocs/release/1.17.0-notes.html#numpy-functions-now-always-support-overrides-with-array-function
下面是一个简单的例子:

import numpy as np
import logging
import inspect

HANDLED_FUNCTIONS = {}

def implements(numpy_function):
    """Register an __array_function__ implementation for MyArray objects."""
    def decorator(func):
        HANDLED_FUNCTIONS[numpy_function] = func
        return func
    return decorator

class MyArray(object):
    def __array_function__(self, func, types, args, kwargs):
        logging.debug('{} {}'.format(inspect.currentframe().f_code.co_name, func))
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        if not all(issubclass(t, MyArray) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)
    
    def __array_ufunc__(self, ufunc, method, inputs, *args, **kwargs):
        logging.debug('{} {}'.format(inspect.currentframe().f_code.co_name, ufunc))
        if ufunc not in HANDLED_FUNCTIONS:
            return NotImplemented
        out = kwargs.pop('out', None)
    
        if out is not None:
            HANDLED_FUNCTIONS[ufunc](inputs, *args, out=out[0], **kwargs)
            return
        else:
            return HANDLED_FUNCTIONS[ufunc](inputs, *args, out=None, **kwargs)
        
    def __init__(self, inlist):
        self.list = inlist[:]        
    
    @property
    def ndim(self):
        return 1

    @property
    def shape(self):
        return (len(self.list), )
    
    @property
    def dtype(self):
        return np.dtype(np.int32)

    def __str__(self):
        return "MyArray " + str(self.list)
    
    def __add__(self, other, *args, **kwargs):
        logging.debug('{}'.format(inspect.currentframe().f_code.co_name))
        return self.add(other, *args, **kwargs)

    @implements(np.add)
    def add(self, *args, **kwargs):
        strng = "MyClass add, out {} {}".format( kwargs.get('out', None), len(args) )
        logging.debug('{} {}'.format(inspect.currentframe().f_code.co_name, strng))
        out = kwargs.get('out', None)
        if out is None:
            return MyArray([el + args[0] for el in self.list])
        else:
            for i,el in enumerate(self.list):
                out[i] = args[0] + el

    # implements np.sum is required when one wants to use the np.sum on this object            
    @implements(np.sum)
    def sum(self, *args, **kwargs):
        return sum(self.list)  # return self.list.ndim    

def main():
    logging.basicConfig(level=logging.DEBUG)

    A = MyArray(np.array([1,2]))
    
    # test np.sum
    print ("sum" , np.sum(A, axis=1))
    
    # test add
    B = A.add(2)
    printit(B, 'B')
    
    out = MyArray([20,30])
    printit(out,'out')
    A.add(2,out=out)
    printit(out,'out')

    # test np.add
    # see comments on __add__ 
    #B = A+2
    B = np.add(A,2)
    printit(B, 'B')

    B = A+2
    printit(B, 'B')

    np.add(A,2,out=out)
    printit(out, "out")

相关问题