__radd__操作的numpy数字结果为__getitem__循环

von4xj4u  于 2023-05-17  发布在  其他
关注(0)|答案(1)|浏览(145)

我正在实现一个类,它可以处理基本的二进制操作,例如加法,乘法,减法和除法,以及它们各自的变体(反转,就地)。我遇到了一个意想不到的行为,我试图理解。不幸的是,即使通过查看numpy的__add__unsignedinteger实现,我也不能。
要重现此行为,只需运行以下代码:

import numpy as np

class test:
    def __init__(self):
        self.a = 1

    def __len__(self):
        return 1

    def __getitem__(self, index):
        print("getitem")
        return self.a

    def __radd__(self, other):
        return self.a + other

a = test()
b = np.uint8(1) + a

这将导致__getitem__循环。当然,我的实际代码的工作方式有点不同,但仍然面临着完全相同的问题。我还尝试使用python调试器,以便更好地理解实际调用的操作的行为。我想做的主要是,当我运行这段代码时:

b = np.uint8(10) + test()

实际执行__radd__操作。我理解这是因为numpy.unsignedinteger.__add__正在执行。有没有什么pythonic的方法来防止或者更好地修复它?

6psbrbz9

6psbrbz91#

我理解这是因为numpy.unsignedinteger.__add__正在执行。有没有什么pythonic的方法来防止或者更好地修复它?
给定表达式np.uint8(10) + test(),您的test类无法阻止numpy的__add__方法的调用。
只有当numpy在其方法中返回NotImplemented时,才会调用__radd__
你看到的无限__getitem__循环正在发生,因为你的类的行为就像一个无限序列,numpy想把它们的标量添加到这个无限序列的每个元素中。永远不会结束。
你有两种方法来解决这个问题。第一种选择:让你的类成为一个适当的有限序列,让numpy处理操作。这不会调用您的__radd__方法。
要使你的类成为一个有限序列,当索引大于0或小于-1(-1)时,你必须在__getitem__中加入raise IndexError
第二个选择:删除__getitem__方法。然后numpy意识到它不能将自己添加到你的类型中,返回NotImplemented并让你的__radd__方法处理操作。

相关问题