numpy ValueError:invalid __array_struct__ when trying to subclass and initialize ndarray in __new__

goucqfw6  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(92)

我试图子类化np.ndarray,以提供一个只包含Cell示例的专用数组,并能够将getattrsetattr转发给数组中包含的所有单元。
(The Cellparent是另一个类的示例,它持有整体,而Pos只是ndarray的另一个子类,形状总是(2,))。
然而,当尝试示例化这个数组时,我得到了这个错误:

File "models.py", line 91, in __new__
    obj[...] = [ [ Cell(parent, Pos(x, y)) for y in range(h) ] for x in range(w) ]
  File "models.py", line 108, in __setitem__
    return super().__setitem__(k, v)
ValueError: invalid __array_struct__
class CellGrid(np.ndarray):

  def __new__(
      subtype, shape, dtype=Cell, buffer=None, offset=0,
      strides=None, order=None, parent=None
  ):
    if len(shape) != 2 :
      raise RuntimeError('A grid cannot be other than 2-dimensionnal')
    if not issubclass(dtype, Cell) :
      raise RuntimeError('A grid can only hold Cells')
    obj = super().__new__(
      subtype, shape, dtype,
      buffer, offset, strides, order
    )
    if parent is not None :
      h, w = shape
      obj[...] = [ [ Cell(parent, Pos(x, y)) for y in range(h) ] for x in range(w) ] # ERROR
    return obj

  @classmethod
  def create(cls, shape, parent):
    return cls(shape, parent=parent)

  def __getitem__(self, k):
    if isinstance(k, Pos_t) :
      return super().__getitem__((k[0], k[1]))
    else :
      return super().__getitem__(k)

  def __setitem__(self, k, v):
    if isinstance(k, Pos_t) :
      return super().__setitem__((k[0], k[1]), v)
    else :
      return super().__setitem__(k, v) # ERROR

  def __getattr__(self, k):
    return np.vectorize(lambda x: getattr(x, k), object)(self)

  def __setattr__(self, k, v):
    if k not in self.__dict__ :
      np.frompyfunc(lambda x: setattr(x, k, v), nin=1, nout=0)(self)

然而,这很棘手找到有关此“无效__array_struct__“错误的信息.

qpgpyjmq

qpgpyjmq1#

实际上,这与ndarray的子类化无关。
这是因为我在Cell类中重新实现了__getattr__,numpy在dtype类中查找__array*属性,返回None对它无效(它应该引发AttributeError)。
快速解决方法:
如果键以'__array'开头,则通过直接快捷方式到object来启动__setattr____getattr__

def __setattr__(self, key, value):
  if(key.startswith('__array'):
    return object.__setattr__(self, key, value)
  # Your code here...

def __getattr__(self, key):
  if(key.startswith('__array'):
    return object.__getattr__(self, key)
  # Your code here...

相关问题