使用特定形状和数据类型的Numpy Typing

qzlgjiam  于 2023-10-19  发布在  其他
关注(0)|答案(2)|浏览(166)

目前,我正试图更多地使用numpy类型来使我的代码更清晰,但不知何故,我已经达到了目前无法覆盖的限制。
是否可以指定特定的形状和相应的数据类型?范例:

Shape=(4,)
datatype= np.int32

到目前为止,我的尝试看起来像下面这样(但都只是抛出错误):
第一次尝试:

import numpy as np

def foo(x: np.ndarray[(4,), np.dtype[np.int32]]):
...
result -> 'numpy._DTypeMeta' object is not subscriptable

第二次尝试:

import numpy as np
import numpy.typing as npt

def foo(x: npt.NDArray[(4,), np.int32]):
...
result -> Too many arguments for numpy.ndarray[typing.Any, numpy.dtype[+ScalarType]]

另外,不幸的是,我在文档中找不到任何关于它的信息,或者当我按照文档中的方式实现它时,我只会得到错误。

bvjxkvbb

bvjxkvbb1#

目前,numpy.typing.NDArray只接受dtype,如下所示:numpy.typing.NDArray[numpy.int32] .不过你有一些选择。

使用typing.Annotated

typing.Annotated允许您为类型创建别名,并将一些额外的信息与之绑定。
在一些my_types.py中,你可以写出你想要暗示的形状的所有变化:

from typing import Annotated, Literal, TypeVar
import numpy as np
import numpy.typing as npt

DType = TypeVar("DType", bound=np.generic)

Array4 = Annotated[npt.NDArray[DType], Literal[4]]
Array3x3 = Annotated[npt.NDArray[DType], Literal[3, 3]]
ArrayNxNx3 = Annotated[npt.NDArray[DType], Literal["N", "N", 3]]

然后在foo.py中,你可以提供一个numpy dtype并将它们用作typehint:

import numpy as np
from my_types import Array4

def foo(arr: Array4[np.int32]):
    assert arr.shape == (4,)

MyPy将识别arrnp.ndarray,并将其检查为np.ndarray。形状检查只能在运行时完成,如本例中的assert
如果你不喜欢这个Assert,你可以用你的创造力来定义一个函数来为你做检查。

def assert_match(arr, array_type):
    hinted_shape = array_type.__metadata__[0].__args__
    hinted_dtype_type = array_type.__args__[0].__args__[1]
    hinted_dtype = hinted_dtype_type.__args__[0]
    assert np.issubdtype(arr.dtype, hinted_dtype), "DType does not match"
    assert arr.shape == hinted_shape, "Shape does not match"

assert_match(some_array, Array4[np.int32])

使用nptyping

另一个选择是使用第三方库nptyping(是的,我是作者)。
你会放弃my_types.py,因为它不再有用了。
你的foo.py会变成这样:

from nptyping import NDArray, Shape, Int32

def foo(arr: NDArray[Shape["4"], Int32]):
    assert isinstance(arr, NDArray[Shape["4"], Int32])

使用beartype + typing.Annotated

还有另一个名为beartype的第三方库,您可以使用。它可以采用typing.Annotated方法的变体,并将为您执行运行时检查。
您将恢复您的my_types.py,内容类似于:

from beartype import beartype
from beartype.vale import Is
from typing import Annotated
import numpy as np

Int32Array4 = Annotated[np.ndarray, Is[lambda array:
    array.shape == (4,) and np.issubdtype(array.dtype, np.int32)]]
Int32Array3x3 = Annotated[np.ndarray, Is[lambda array:
    array.shape == (3,3) and np.issubdtype(array.dtype, np.int32)]]

你的foo.py会变成:

import numpy as np
from beartype import beartype
from my_types import Int32Array4 

@beartype
def foo(arr: Int32Array4):
    ...  # Runtime type checked by beartype.

使用beartype + nptyping

你也可以把这两个库叠加起来。
您的my_types.py可以再次删除,您的foo.py将变成如下所示:

from nptyping import NDArray, Shape, Int32
from beartype import beartype

@beartype
def foo(arr: NDArray[Shape["4"], Int32]):
    ...  # Runtime type checked by beartype.
cbjzeqam

cbjzeqam2#

我习惯这样做:

def foo(x):
    x = np.array(x, dtype=np.int32)
    if x.shape!=Shape:
        raise ValueError("Shape mismatch")
    #...

如果你有特定的问题与形状,你应该重塑之前,这取决于输入形状,你希望有。如果您需要帮助来正确地重塑输入x,请提供输入x的示例。

相关问题