numpy 使自定义点访问dict与np.array()兼容

watbbzwu  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(122)

我有一个自定义的dict类,允许点访问:

class dotDict(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

我需要它与numpy.array()兼容。就像这样,我得到错误KeyError: '__array_struct__'。如果我试图在AI工具的帮助下定义一个自定义方法__array_struct__,我总是得到ValueError: invalid __array_struct__
我已经尝试了很多变种沿着

def __array_struct__(self):
        return self.data.__array_interface__["data"], self.data.dtype

但我总是得到ValueError: invalid __array_struct__我也尝试过定义其他方法,如__array__,但没有成功。我发现很难找到关于如何正确执行此操作的文档。
有什么想法吗?
背景:我需要类与np.atleast_1d兼容,它在内部调用asanyarray,这是相当模糊的,但它似乎调用np.array()

vmjh9lq9

vmjh9lq91#

我找到解决办法了!我的班级现在是:

class dotDict(dict):
    def __getattr__(self, item):
        if item == "__array_struct__" or item == "__array_interface__":
            raise AttributeError
        else:
            return self.get(item)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __array__(self):
        return np.array(self.__dict__())

    def __dict__(self):
        out_dict = dict()
        for key, value in self.items():
            out_dict[key] = value
        return out_dict

当尝试将未知对象转换为数组时,numpy首先查找__array_interface____array_struct__。如果这些不存在(我们在这里通过AttributeError显式地告诉numpy,那么它福尔斯到__array__。由于这里的数组表示非常简单(只是一个包含来自我们的自定义dict的数据的数组作为普通dict),这已经足够好了。

相关问题