无法消除关于错误类型numpy.bool_ and bool的mypy错误

utugiqy6  于 2023-03-08  发布在  其他
关注(0)|答案(1)|浏览(117)

我有一个包含几个np.array的类:

class VECMParams(ModelParams):
    def __init__(
        self,
        ecm_gamma: np.ndarray,
        ecm_mu: Optional[np.ndarray],
        ecm_lambda: np.ndarray,
        ecm_beta: np.ndarray,
        intercept_coint: bool,
    ):
        self.ecm_gamma = ecm_gamma 
        self.ecm_mu = ecm_mu
        self.ecm_lambda = ecm_lambda 
        self.ecm_beta = ecm_beta
        self.intercept_coint = intercept_coint

我想重写==操作符,基本上,当所有数组都等于rhs 1时,VECMParam等于另一个:

def __eq__(self, rhs: object) -> bool:
    if not isinstance(rhs, VECMParams):
        raise NotImplementedError()

    return (
        np.all(self.ecm_gamma == rhs.ecm_gamma) and
        np.all(self.ecm_mu == rhs.ecm_mu) and
        np.all(self.ecm_lambda == rhs.ecm_lambda) and
        np.all(self.ecm_beta == rhs.ecm_beta) 
    )

尽管如此,mypy还是一直说Incompatible return value type (got "Union[bool_, bool]", expected "bool") [return-value],因为np.all返回bool_,而__eq__需要返回本机bool。我搜索了几个小时,看起来没有办法将这些bool_转换为本机bool。有人遇到同样的问题吗?
PS:执行my_bool_ is True未计算为正确的本机bool值

fruv7luv

fruv7luv1#

看看numpy.all()

A new boolean or array is returned unless out is specified, in which case a reference to out is returned.

这是Union[ndarray, bool]
如何修复:

def __eq__(self, rhs: 'VECMParams') -> bool:
    if not isinstance(rhs, VECMParams):
        raise NotImplementedError()

    return bool(
        np.all(self.ecm_gamma == rhs.ecm_gamma) and
        np.all(self.ecm_mu == rhs.ecm_mu) and
        np.all(self.ecm_lambda == rhs.ecm_lambda) and
        np.all(self.ecm_beta == rhs.ecm_beta) 
    )

相关问题