如果我创建一个包含Numpy ndarray的Python类,我就不能再使用自动生成的__eq__
了。
import numpy as np
@dataclass
class Instr:
foo: np.ndarray
bar: np.ndarray
arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))
ValueError:具有多个元素的数组的真值是不明确的。使用.any()或a.all()
这是因为ndarray.__eq__
* 有时 * 会返回ndarray
的真值,通过比较a[0]
和b[0]
,以此类推,直到2中的较长者。这是相当复杂和不直观的,事实上,只有当数组的形状不同,或具有不同的值或其他东西时才会引发错误。
如何安全地比较持有Numpy数组的@dataclass
es?@dataclass
的__eq__
实现是使用eval()
生成的。它的源代码从堆栈跟踪中丢失,并且不能使用inspect
查看,但它实际上使用了一个 * 元组比较 *,调用了bool(foo)。
import dis
dis.dis(Instr.__eq__)
摘录:
3 12 LOAD_FAST 0 (self)
14 LOAD_ATTR 1 (foo)
16 LOAD_FAST 0 (self)
18 LOAD_ATTR 2 (bar)
20 BUILD_TUPLE 2
22 LOAD_FAST 1 (other)
24 LOAD_ATTR 1 (foo)
26 LOAD_FAST 1 (other)
28 LOAD_ATTR 2 (bar)
30 BUILD_TUPLE 2
32 COMPARE_OP 2 (==)
34 RETURN_VALUE
2条答案
按热度按时间pxq42qpu1#
解决方案是放入您自己的
__eq__
方法并设置eq=False
,这样类就不会生成自己的方法(尽管检查文档,最后一步是不必要的,但我认为无论如何显式显示是很好的)。编辑
一个通用的快速解决方案,用于一些值是numpy数组而另一些不是的通用类
xtupzzrd2#
如果您使用attrs而不是xmlasses,则可以自定义: