类型提示numpy数组和批处理

5kgi1eie  于 2023-02-23  发布在  其他
关注(0)|答案(1)|浏览(134)

我正在尝试为一个科学的python项目创建一些数组类型,到目前为止,我已经为1D、2D和ND numpy数组创建了泛型类型:

from typing import Any, Generic, Protocol, Tuple, TypeVar

import numpy as np
from numpy.typing import _DType, _GenericAlias

Vector = _GenericAlias(np.ndarray, (Tuple[int], _DType))
Matrix = _GenericAlias(np.ndarray, (Tuple[int, int], _DType))
Tensor = _GenericAlias(np.ndarray, (Tuple[int, ...], _DType))

第一个问题是mypy说VectorMatrixTensor不是有效类型(例如,当我尝试myvar: Vector[int] = np.array([1, 2, 3])时)
第二个问题是,我想创建一个泛型类型Batch,我想这样使用它:Batch[Vector[complex]]应该像Matrix[complex]Batch[Matrix[float]]应该像Tensor[float]Batch[Tensor[int]应该像Tensor[int],我不知道我所说的“应该像”是什么意思,我想我的意思是mypy不应该抱怨。
我该怎么做呢?

cwtwac6a

cwtwac6a1#

您不应该从外部使用受保护的成员(名称以下划线开头)。它们通常以这种方式标记,以指示将来可能更改的实现细节,这正是numpy版本之间发生的情况。例如,在1.24中,您的导入行from numpy.typing在运行时失败,因为您尝试导入的成员不再存在。
不需要使用内部别名构造函数,因为numpy.ndarray在数组 * shape * 及其 * dtype * 方面已经是泛型的。您可以相当容易地构造自己的类型别名。您只需要确保正确地参数化 * dtype *。下面是一个工作示例:

from typing import Tuple, TypeVar

import numpy as np

T = TypeVar("T", bound=np.generic, covariant=True)

Vector = np.ndarray[Tuple[int], np.dtype[T]]
Matrix = np.ndarray[Tuple[int, int], np.dtype[T]]
Tensor = np.ndarray[Tuple[int, ...], np.dtype[T]]

用法:

def f(v: Vector[np.complex64]) -> None:
    print(v[0])

def g(m: Matrix[np.float_]) -> None:
    print(m[0])

def h(t: Tensor[np.int32]) -> None:
    print(t.reshape((1, 4)))

f(np.array([0j+1]))  # prints (1+0j)
g(np.array([[3.14, 0.], [1., -1.]]))  # prints [3.14 0.  ]
h(np.array([[3.14, 0.], [1., -1.]]))  # prints [[ 3.14  0.    1.   -1.  ]]

目前的问题是shapes have almost no typing support,但是使用PEP 646提供的新TypeVarTuple功能来实现它的工作正在进行中,在此之前,通过形状区分类型几乎没有实际用途。
批次问题应该是一个单独的问题。尝试一次问一个问题。

相关问题