如何比较持有numpy.ndarray(bool(a==b)引发ValueError)的类的相等性?

carvr3hs  于 2023-10-19  发布在  其他
关注(0)|答案(2)|浏览(114)

如果我创建一个包含Numpy ndarray的Python类,我就不能再使用自动生成的__eq__了。

import numpy as np

@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray

arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))

ValueError:具有多个元素的数组的真值是不明确的。使用.any()或a.all()
这是因为ndarray.__eq__ * 有时 * 会返回ndarray的真值,通过比较a[0]b[0],以此类推,直到2中的较长者。这是相当复杂和不直观的,事实上,只有当数组的形状不同,或具有不同的值或其他东西时才会引发错误。
如何安全地比较持有Numpy数组的@dataclass es?
@dataclass__eq__实现是使用eval()生成的。它的源代码从堆栈跟踪中丢失,并且不能使用inspect查看,但它实际上使用了一个 * 元组比较 *,调用了bool(foo)。

import dis
dis.dis(Instr.__eq__)

摘录:

3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE
pxq42qpu

pxq42qpu1#

解决方案是放入您自己的__eq__方法并设置eq=False,这样类就不会生成自己的方法(尽管检查文档,最后一步是不必要的,但我认为无论如何显式显示是很好的)。

import numpy as np

def array_eq(arr1, arr2):
    return (isinstance(arr1, np.ndarray) and
            isinstance(arr2, np.ndarray) and
            arr1.shape == arr2.shape and
            (arr1 == arr2).all())

@dataclass(eq=False)
class Instr:

    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)

编辑

一个通用的快速解决方案,用于一些值是numpy数组而另一些不是的通用类

import numpy as np
from dataclasses import dataclass, astuple

def array_safe_eq(a, b) -> bool:
    """Check if a and b are equal, even if they are numpy arrays"""
    if a is b:
        return True
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return a.shape == b.shape and (a == b).all()
    try:
        return a == b
    except TypeError:
        return NotImplemented

def dc_eq(dc1, dc2) -> bool:
   """checks if two dataclasses which hold numpy arrays are equal"""
   if dc1 is dc2:
        return True
   if dc1.__class__ is not dc2.__class__:
       return NotImplmeneted  # better than False
   t1 = astuple(dc1)
   t2 = astuple(dc2)
   return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))

# usage
@dataclass(eq=False)
class T:

   a: int
   b: np.ndarray
   c: np.ndarray

   def __eq__(self, other):
        return dc_eq(self, other)
xtupzzrd

xtupzzrd2#

如果您使用attrs而不是xmlasses,则可以自定义:

from attrs import define, field
import numpy

@define
class C:
   an_array = field(eq=attr.cmp_using(eq=numpy.array_equal))

相关问题