我想我正确地遵循了子类化教程。我有一个非常简单的例子。我只运行一次代码就可以了。当我在Jupyter notebook中重新运行一个单元格时,类会中断,并且它会“忘记”状态(它会记住我添加的内容,它会忘记我对numpy数组所做的转置)。请参见下面的代码。
下面我实现了三个简单的类NamedAxis、NamedAxes和NamedArray(是的,我知道xarray,这是为了我自己的学习目的)。大多数情况下,它工作得很好。然而,当我重新运行flip时,我注意到一些非常令人沮丧的事情
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Dict, Union, Optional, Any, Callable, TypeVar, Generic, Type, cast, Tuple
import numpy as np, pandas as pd
@dataclass
class NamedAxis:
# name of axis
name: str
# index of axis
axis: Optional[int] = None
def __str__(self):
return f'{self.name}({self.axis})'
__repr__ = __str__
def copy(self) -> 'NamedAxis':
copy = deepcopy(self)
return copy
@dataclass
class NamedAxes:
axes: Union[List[NamedAxis], Tuple[NamedAxis]]
name: Optional[str] = 'NamedAxes'
umap: Dict[str, NamedAxis] = field(default_factory=dict, init=False, repr=False)
def __post_init__(self):
# assign unique id to each axis
for i, axis in enumerate(self.axes):
axis.axis = i
self.umap = {ax.axis: ax for ax in self.axes}
@property
def ndim(self):
return len(self.axes)
@property
def anames(self):
# names in current location
return [str(ax.name) for ax in self.axes]
@property
def aidxs(self):
# original location as ax.axis should never be changed
return [int(ax.axis) for ax in self.axes]
@property
def alocs(self):
# current location
return list(range(len(self)))
def __getitem__(self, key:Union[int, str, NamedAxis]) -> NamedAxis:
# NOTE: this gets current location of axis, not original location
if isinstance(key, int):
return self.axes[key]
# NOTE: this gets location based off original location
elif isinstance(key, NamedAxis):
return self.umap[key.axis]
# NOTE: this gets location based off original location
elif isinstance(key, str):
for ax in self.umap.values():
if key == ax.name:
return ax
elif key == str(ax.axis):
return ax
else:
raise KeyError(f'Key {key} not found in {self.name}')
def __str__(self):
_str = f'{self.name}(' + ', '.join(self.anames) + ')'
return _str
__repr__ = __str__
def __iter__(self):
return iter(self.axes)
def __len__(self):
return len(self.axes)
def copy(self):
copy = deepcopy(self)
copy.umap = self.umap.copy()
return copy
def index(self, key:Union[int, str, NamedAxis]):
ax = self[key]
return self.axes.index(ax)
def transpose(self, *order:Union[str, int, NamedAxis]):
# check input and convert to axes
update_axes = [self[key] for key in order]
# gather the axes that are not in the provided order
needed_axes = [ax for ax in self.axes if ax not in update_axes]
# the new order of axes is the updated axes followed by the needed axes
new_order = update_axes + needed_axes
print('NamedAxes.transpose:\t', self.name, self.axes, new_order)
# rearrange axes according to the new order
self.axes = new_order
return self
a, b, c = NamedAxis('axis-a'), NamedAxis('axis-b'), NamedAxis('axis-c')
abc = NamedAxes((a, b, c))
abc
class NamedArray(np.ndarray):
DIMS = NamedAxes([NamedAxis('axis-a'), NamedAxis('axis-b'), NamedAxis('axis-c')], name='Trajectories')
def __new__(cls, arr, dims=None):
obj = np.asarray(arr).view(cls)
obj.dims = (dims or cls.DIMS).copy()
return obj
def __new__(cls, arr, dims:NamedAxes=None):
# Input array is an already formed ndarray instance
# We first cast to be our class type
obj = np.asarray(arr).view(cls)
# add the new attribute to the created instance
obj.dims = (dims or cls.DIMS).copy()
# Finally, we must return the newly created object:
return obj
def __array_finalize__(self, obj):
print('finalize, dims=', getattr(obj, 'dims', None))
print('finalize, obj=', obj)
if obj is None: return
self.dims = getattr(obj, 'dims', self.DIMS.copy())
# Ensure the indices are in the correct range
shape = self.shape
if len(shape) != len(self.dims):
raise ValueError('NamedArray must have {len(self.dims)} dimensions, but got {len(shape)}.')
def __array_wrap__(self, out, dims=None):
print('In __array_wrap__:')
print(' self is %s' % repr(self))
print(' arr is %s' % repr(out))
# then just call the parent
return super().__array_wrap__(self, out, dims)
@property
def dim_names(self):
return tuple(self.dims.anames)
@property
def dim_str(self):
_str = ', '.join([f'{s} {n}' for s, n in zip(self.shape, self.dim_names)])
return f'({_str})'
def __repr__(self):
base = super(NamedArray, self).__repr__()
first_line = base.split('\n')[0]
spaces = 0
for s in first_line:
if s.isdigit():
break
spaces += 1
spaces = ' ' * (spaces - 1)
return f'{base}\n{spaces}{self.dim_str}'
def flip(self, axes:Union[str, int, NamedAxis]=None):
# I tried transpose as well
print(self.dims.axes)
# Get the order of axes indices
new_idxs = [self.dims.index(self.dims[ax]) for ax in axes]
print(axes, new_idxs)
# Transpose the NamedAxes
self.dims.transpose(*axes)
print(new_idxs, self.__array_interface__['shape'])
# Transpose the underlying numpy array
self = np.transpose(self, axes=new_idxs)
# self.transpose(*new_idxs)
'''
# NOTE: StackOverflow post edit / clarification
I've tried this a few different ways including
`self.transpose()` as well as just `return np.transpose()`,
and trying to change the function flip to `transpose` etc.
This is just the version I am posting for brevity without
the 10 different `flip` implementations
'''
return self
所以让我们制作一些虚拟数据:
arr = np.random.randint(0, 5, (2, 3, 4))
nar = NamedArray(arr)
nar
# (2 axis-a, 3 axis-b, 4 axis-c)
''' NOTE: flip is basically transpose, with the difference that
`arr.transpose(1, 0, 2).transpose(1, 0, 2)` will do two transposes
but since we are using names and named indices, `nar.flip('b', 'a', 'c').flip('b', 'a', 'c')` should only do one. In other words `flip` is declarative, saying how we want the axes to be. Similar to einops / xarray
'''
nar.flip(('axis-c', 'axis-b', 'axis-a'))
# (4 axis-c, 3 axis-b, 2 axis-a)
到目前为止一切顺利。然而,当我再次运行单元格时
# (2 axis-a, 3 axis-b, 4 axis-c)
nar.flip(('axis-c', 'axis-b', 'axis-a'))
# (2 axis-c, 3 axis-b, 4 axis-a)
我花了太长时间调试这个,我不能弄清楚。
2条答案
按热度按时间9q78igpj1#
在这一行中,创建一个名为self的局部变量。因此变量self(flip的参数)永远不会被修改
转置底层numpy数组
self = np.transpose(self,axes=new_idxs)
2fjabf4q2#