如何在NumPy ndarrays中支持修改后的数据解释?

l0oc07j2  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(77)

我正在尝试编写一个Python 3类,将一些数据存储在NumPy np.ndarray中。但是,我希望我的类还包含一条关于如何解释数据值的信息。
例如,我们假设ndarraydtypenp.float32,但还有一个“color”修改了这些浮点值的含义。因此,如果我想添加一个red数字和一个blue数字,我必须首先将这两个数字转换为洋红,以便合法地添加它们的底层_data数组。加法的结果将是_color = "magenta"
这只是一个玩具的例子。实际上,“颜色”不是一个字符串(最好将其视为整数),结果的“颜色”是从两个输入的“颜色”中数学确定的,并且任何两个“颜色”之间的转换都是数学定义的。

class MyClass:
    
    def __init__(self, data : np.ndarray, color : str):
        self._data = data
        self._color = color
    
    
    # Example: Adding red numbers and blue numbers produces magenta numbers
    def convert(self, other_color):
        if self._color == "red" and other_color == "blue":
            return MyClass(10*self._data, "magenta")
        elif self._color == "blue" and other_color == "red":
            return MyClass(self._data/10, "magenta")
    
    
    def __add__(self, other):
        if other._color == self._color:
            # If the colors match, then just add the data values
            return MyClass(self._data + other._data, self._color)
        else:
            # If the colors don't match, then convert to the output color before adding
            new_self = self.convert(other._color)
            new_other = other.convert(self._color)
            return new_self + new_other

字符串
我的问题是_color信息与_data共存。因此,我似乎无法为我的类定义合理的索引行为:

  • 如果我将__getitem__定义为返回self._data[i],那么_color信息将丢失。
  • 如果我定义__getitem__返回MyClass(self._data[i], self._color),那么我创建了一个包含标量数的新对象。这将导致大量的问题(例如,我可以合法地索引that_object[i],导致某些错误。
  • 如果我定义了__getitem__来返回MyClass(self._data[i:i+1], self._color),那么我就是在索引一个数组来得到一个数组,这会导致很多其他的问题。例如,my_object[i] = my_object[i]看起来很合理,但会抛出错误。

然后我开始想,我真正想要的是每种不同的“颜色”都有一个不同的dtype。这样,索引值将在dtype...中免费编码“颜色”信息。但我不知道如何实现
“颜色”的理论总数可能大约为100,000。但是,在任何单个脚本执行中使用的数量将少于100个。所以,我想可能会维护一个列表/字典/?使用的“颜色”以及它们如何Map到动态生成的类…但是Python倾向于以我意想不到的方式悄悄地转换类型,所以这可能不是正确的方法。
我所知道的是,我不想把“颜色”和每个数据值一起存储。数据阵列可以是~数十亿个条目,所有条目具有一种“颜色”。
我怎样才能在拥有一个可用类的同时跟踪这个“颜色”信息呢?

iezvtpos

iezvtpos1#

定义每个dunder(__add__等)是站不住脚的。从np.ndarray继承可能也是站不住脚的,这是兼容的派生类所需要的。
你可以用一个薄薄的 Package 纸:

from typing import NamedTuple, Sequence, Any, Callable

import numpy as np

Colour = int
RED: Colour = 0x0000FF
MAGENTA: Colour = 0xFF00FF
BLUE: Colour = 0xFF0000

def common_colour(colours: Sequence[Colour]) -> Colour:
    # magic happens here
    return sum(colours)

class ColouredArray(NamedTuple):
    colour: Colour
    data: np.ndarray

    def __str__(self) -> str:
        return f'({self.colour}) {self.data}'

    def convert(self, new_colour: Colour) -> 'ColouredArray':
        return ColouredArray(
            # magic happens here
            data=self.data * new_colour/self.colour,
            colour=new_colour,
        )

def all_common(arrays: Sequence['ColouredArray']) -> tuple['ColouredArray']:
    new_colour = common_colour([a.colour for a in arrays])
    return tuple(
        array.convert(new_colour) for array in arrays
    )

def call_common(method: Callable, *args, **kwargs) -> tuple[Colour, Any]:
    new_colour = common_colour([
        arg.colour
        for arg in (*args, *kwargs.values())
        if isinstance(arg, ColouredArray)
    ])
    return new_colour, method(
        *(
            arg.convert(new_colour).data if isinstance(arg, ColouredArray) else arg
            for arg in args
        ),
        **{
            k: arg.convert(new_colour).data if isinstance(arg, ColouredArray) else arg
            for k, arg in kwargs.items()
        },
    )

y = ColouredArray(*call_common(
    np.interp,
    x=ColouredArray(RED, np.arange(5)),
    xp=ColouredArray(RED, np.arange(1, 30, 2)),
    fp=ColouredArray(BLUE, np.arange(11, 40, 2)),
    left=-1,
))
print(y)

字符串
在这个例子中,来自call_common的元组被直接解压缩到ColouredArray构造函数中,因为interp返回一个np.ndarray。在其他情况下,例如从numpy.linalg.lstsq返回的4元组,将由调用者根据需要解包并重建彩色数组。

相关问题