numpy 使用ArrayLike时的Mypy错误

mxg2im7a  于 2023-05-17  发布在  其他
关注(0)|答案(2)|浏览(145)

我不明白我应该如何在代码中使用ArrayLike。如果检查mypy,当我尝试在不调用强制转换的情况下使用变量时,我会不断收到错误。我正在尝试定义函数签名,它可以与ndarray以及常规列表一起使用。
例如,下面的代码

import numpy.typing as npt
import numpy as np

from typing import Any

def f(a: npt.ArrayLike) -> int:
    return len(a)

def g(a: npt.ArrayLike) -> Any:
    return a[0]

print(f(np.array([0, 1])), g(np.array([0, 1])))
print(f([0, 1]), g([0, 1]))

给予f()和g()的错误:

Argument 1 to "len" has incompatible type "Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]]"; expected "Sized"  [arg-type]

Value of type "Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]]" is not indexable  [index]
b1zrtrql

b1zrtrql1#

numpy.typing.ArrayLike的目的是能够注解
可以强制转换为ndarray的对象。
考虑到这一目的,他们将类型定义为以下联合:

Union[
    _SupportsArray[dtype[Any]],
    _NestedSequence[_SupportsArray[dtype[Any]]],
    bool,
    int,
    float,
    complex,
    str,
    bytes,
    _NestedSequence[Union[bool, int, float, complex, str, bytes]]
]

_SupportsArray只是一个带有__array__方法的协议。它既不需要实现__len__(用于len函数),也不需要实现__getitem__(用于索引)。
_NestedSequence是一个限制性更强的协议,它实际上需要__len____getitem__
但这段代码的问题是参数注解是那个union

import numpy.typing as npt

...

def f(a: npt.ArrayLike) -> int:
    return len(a)

所以a * 可能 * 是一个支持__len__的类似序列的对象,但它 * 也可能 * 只是一个支持__array__的对象。例如,它甚至可以只是一个int(再次参见联合)。因此调用len(a)是不安全的。
类似地,这里的项访问不是类型安全的,因为a可能没有实现__getitem__

...

def g(a: npt.ArrayLike) -> Any:
    return a[0]

所以它对你不起作用的原因是,它不应该被用作numpy数组或其他序列的注解;它的目的是用于可以 * 转换 * 成numpy数组的东西。
如果你想注解你的函数fg以接受列表和numpy数组,你可以只使用listNDArray的并集,比如list[Any] | npt.NDArray[Any]
如果你想有一个更宽的注解来容纳任何有__len____getitem__的类型,你需要定义你自己的protocol

from typing import Any, Protocol, TypeVar

import numpy as np

T = TypeVar("T", covariant=True)

class SequenceLike(Protocol[T]):
    def __len__(self) -> int: ...
    def __getitem__(self, item: int) -> T: ...

def f(a: SequenceLike[Any]) -> int:
    return len(a)

def g(a: SequenceLike[T]) -> T:
    return a[0]

print(f(np.array([0, 1])), g(np.array([0, 1])))
print(f([0, 1]), g([0, 1]))

更准确地说,__getitem__可能也应该接受slice对象,但重载可能对您来说有些过分。

oyt4ldly

oyt4ldly2#

尝试遵循mypy配置下的numpy.typing,然后是建议的settings

[mypy]
plugins = numpy.typing.mypy_plugin

相关问题