如何键入hint函数以兼容numpy

ql3eal8s  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(82)

example.py源代码:

from typing import Union, Any
import numpy as np

Number = Union[int, float, np.floating[Any]]

def add_one(num: Number) -> Number:
    return num + 1

inputs = [1, 2, 3]
outputs = [add_one(n) for n in inputs]

avg = np.mean(outputs)

运行mypy:

mypy example.py
src/example.py:14: error: Argument 1 to "mean" has incompatible type "List[Union[float, floating[Any]]]"; expected "Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]]"
Found 1 error in 1 file (checked 1 source file)

我可以将所有内容都更改为np.floating[Any],这解决了numpy问题,但之后我必须将原语转换为np.float32(...)

from typing import Any
import numpy as np

def add_one(num: np.floating[Any]) -> np.floating[Any]:
    return num + 1

inputs = [1, 2, 3]
outputs = [add_one(np.float32(n)) for n in inputs]

avg = np.mean(outputs)

有没有一种正确的方法来类型提示add_one函数,使其输出与numpy函数(如np.mean)兼容,而不破坏与python基本类型的兼容性?最终目标是能够像这样使用它:

inputs = [1, 2, 3]
outputs = [add_one(n) for n in inputs]
avg = np.mean(outputs)
pb3skfrl

pb3skfrl1#

只使用Number = Union[int, float]不会抛出任何mypy错误。

相关问题