numpy 在下面的教程中子类化ndarray会产生意想不到的结果(即部分记忆,一些属性被记住,其他属性丢失)

7fyelxc5  于 2023-06-23  发布在  其他
关注(0)|答案(2)|浏览(153)

我想我正确地遵循了子类化教程。我有一个非常简单的例子。我只运行一次代码就可以了。当我在Jupyter notebook中重新运行一个单元格时,类会中断,并且它会“忘记”状态(它会记住我添加的内容,它会忘记我对numpy数组所做的转置)。请参见下面的代码。
下面我实现了三个简单的类NamedAxis、NamedAxes和NamedArray(是的,我知道xarray,这是为了我自己的学习目的)。大多数情况下,它工作得很好。然而,当我重新运行flip时,我注意到一些非常令人沮丧的事情

from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Dict, Union, Optional, Any, Callable, TypeVar, Generic, Type, cast, Tuple
import numpy as np, pandas as pd

@dataclass
class NamedAxis:
    # name of axis
    name: str
    # index of axis
    axis: Optional[int] = None

    def __str__(self):
        return f'{self.name}({self.axis})'
    
    __repr__ = __str__

    def copy(self) -> 'NamedAxis':
        copy = deepcopy(self)
        return copy

    
@dataclass    
class NamedAxes:
    axes: Union[List[NamedAxis], Tuple[NamedAxis]]
    name: Optional[str] = 'NamedAxes'
    umap: Dict[str, NamedAxis] = field(default_factory=dict, init=False, repr=False)
    
    def __post_init__(self):
        # assign unique id to each axis
        for i, axis in enumerate(self.axes):
            axis.axis = i
        
        self.umap = {ax.axis: ax for ax in self.axes}

    @property
    def ndim(self):
        return len(self.axes)
    
    @property
    def anames(self):
        # names in current location
        return [str(ax.name) for ax in self.axes]
    
    @property
    def aidxs(self):
        # original location as ax.axis should never be changed
        return [int(ax.axis) for ax in self.axes]
    
    @property
    def alocs(self):
        # current location
        return list(range(len(self)))

    def __getitem__(self, key:Union[int, str, NamedAxis]) -> NamedAxis:
        # NOTE: this gets current location of axis, not original location
        if isinstance(key, int):
            return self.axes[key]
        
        # NOTE: this gets location based off original location
        elif isinstance(key, NamedAxis):
            return self.umap[key.axis]

        # NOTE: this gets location based off original location
        elif isinstance(key, str):
            for ax in self.umap.values():
                if key == ax.name:
                    return ax
                
                elif key == str(ax.axis):
                    return ax    
        else:
            raise KeyError(f'Key {key} not found in {self.name}')
        
    def __str__(self):
        _str = f'{self.name}(' + ', '.join(self.anames) + ')'
        return _str
    
    __repr__ = __str__
    
    def __iter__(self):
        return iter(self.axes)

    def __len__(self):
        return len(self.axes)
    
    def copy(self):
        copy = deepcopy(self)
        copy.umap = self.umap.copy()
        return copy
    
    def index(self, key:Union[int, str, NamedAxis]):
        ax = self[key]
        return self.axes.index(ax)

    def transpose(self, *order:Union[str, int, NamedAxis]):
        # check input and convert to axes
        update_axes = [self[key] for key in order]

        # gather the axes that are not in the provided order
        needed_axes = [ax for ax in self.axes if ax not in update_axes]
        
        # the new order of axes is the updated axes followed by the needed axes
        new_order = update_axes + needed_axes
        print('NamedAxes.transpose:\t', self.name, self.axes, new_order)

        # rearrange axes according to the new order
        self.axes = new_order
        return self

a, b, c = NamedAxis('axis-a'), NamedAxis('axis-b'), NamedAxis('axis-c')
abc = NamedAxes((a, b, c))
abc



class NamedArray(np.ndarray):
    DIMS = NamedAxes([NamedAxis('axis-a'), NamedAxis('axis-b'), NamedAxis('axis-c')], name='Trajectories')
    
    def __new__(cls, arr, dims=None):
        obj = np.asarray(arr).view(cls)        
        obj.dims = (dims or cls.DIMS).copy()
        return obj
    
    def __new__(cls, arr, dims:NamedAxes=None):
        # Input array is an already formed ndarray instance
        # We first cast to be our class type
        obj = np.asarray(arr).view(cls)
        # add the new attribute to the created instance
        obj.dims = (dims or cls.DIMS).copy()      
        # Finally, we must return the newly created object:
        return obj
            
    def __array_finalize__(self, obj):
        print('finalize, dims=', getattr(obj, 'dims', None))
        print('finalize, obj=', obj)
        if obj is None: return        
        self.dims = getattr(obj, 'dims', self.DIMS.copy())

        # Ensure the indices are in the correct range
        shape = self.shape
        if len(shape) != len(self.dims):
            raise ValueError('NamedArray must have {len(self.dims)} dimensions, but got {len(shape)}.')
        
    def __array_wrap__(self, out, dims=None):
        print('In __array_wrap__:')
        print('   self is %s' % repr(self))
        print('   arr is %s' % repr(out))
        # then just call the parent
        return super().__array_wrap__(self, out, dims)
    
    
    @property
    def dim_names(self):
        return tuple(self.dims.anames)
            
    @property
    def dim_str(self):
        _str = ', '.join([f'{s} {n}' for s, n in zip(self.shape, self.dim_names)])
        return f'({_str})'
             
    def __repr__(self):
        base = super(NamedArray, self).__repr__()        
        first_line = base.split('\n')[0]
        spaces = 0
        for s in first_line:            
            if s.isdigit():
                break
            spaces += 1
        spaces = ' ' * (spaces - 1)
        return f'{base}\n{spaces}{self.dim_str}'
    

    
    def flip(self, axes:Union[str, int, NamedAxis]=None):
        # I tried transpose as well
        print(self.dims.axes)    
        # Get the order of axes indices        
        new_idxs = [self.dims.index(self.dims[ax]) for ax in axes]
        print(axes, new_idxs)

        # Transpose the NamedAxes
        self.dims.transpose(*axes)        
        print(new_idxs, self.__array_interface__['shape'])
        
        # Transpose the underlying numpy array
        self = np.transpose(self, axes=new_idxs)
        # self.transpose(*new_idxs)
        

        '''
        # NOTE: StackOverflow post edit / clarification
        I've tried this a few different ways including 
        `self.transpose()` as well as just `return np.transpose()`, 
        and trying to change the function flip to `transpose` etc. 
        This is just the version I am posting for brevity without 
        the 10 different `flip` implementations
        '''

        return self

所以让我们制作一些虚拟数据:

arr = np.random.randint(0, 5, (2, 3, 4))
nar = NamedArray(arr)
nar
# (2 axis-a, 3 axis-b, 4 axis-c)

''' NOTE: flip is basically transpose, with the difference that 
`arr.transpose(1, 0, 2).transpose(1, 0, 2)` will do two transposes
but since we are using names and named indices, `nar.flip('b', 'a', 'c').flip('b', 'a', 'c')` should only do one. In other words `flip` is declarative, saying how we want the axes to be. Similar to einops / xarray
'''

nar.flip(('axis-c', 'axis-b', 'axis-a'))
# (4 axis-c, 3 axis-b, 2 axis-a)

到目前为止一切顺利。然而,当我再次运行单元格时

# (2 axis-a, 3 axis-b, 4 axis-c)
nar.flip(('axis-c', 'axis-b', 'axis-a'))
# (2 axis-c, 3 axis-b, 4 axis-a)

我花了太长时间调试这个,我不能弄清楚。

9q78igpj

9q78igpj1#

在这一行中,创建一个名为self的局部变量。因此变量self(flip的参数)永远不会被修改

转置底层numpy数组

self = np.transpose(self,axes=new_idxs)

2fjabf4q

2fjabf4q2#

# try this
def flip(self, axes:Union[str, int, NamedAxis]=None):
        # I tried transpose as well
        print(self.dims.axes)    
        # Get the order of axes indices        
        new_idxs = [self.dims.index(self.dims[ax]) for ax in axes]
        print(axes, new_idxs)

        # Transpose the NamedAxes
        self.dims.transpose(*axes)        
        print(new_idxs, self.__array_interface__['shape'])
        
        # Transpose the underlying numpy array
        ndarr = np.transpose(self, axes=new_idxs)
        self.shape = ndarr.shape
        # self.transpose(*new_idxs)
        

        '''
        # NOTE: StackOverflow post edit / clarification
        I've tried this a few different ways including 
        `self.transpose()` as well as just `return np.transpose()`, 
        and trying to change the function flip to `transpose` etc. 
        This is just the version I am posting for brevity without 
        the 10 different `flip` implementations
        '''

相关问题