pandas mypy索引带有Enum的pd.DataFrame会引发无重载变量错误

ffx8fchx  于 2023-04-04  发布在  其他
关注(0)|答案(1)|浏览(144)
问题

Mypy给出了“DataFrame”的__getitem__没有重载变体匹配参数类型“MyEnum”错误。在这种情况下,参数类型是Enum,但任何其他自定义类型都会发生这个问题。下面是__get_item__的签名。

def __getitem__(self, Union[str, bytes, date, datetime, timedelta, bool, int, float, complex, Timestamp, Timedelta], /) -> Series[Any]
复制

下面是一个脚本(即 mypy_enum.py),它创建了一个以枚举为列的pandas Dataframe 。

from enum import Enum
import pandas as pd

class MyEnum(Enum):
    TAYYAR = "tayyar"
    HAYDAR = "haydar"

df = pd.DataFrame(data = [[12.2, 10], [8.8, 15], [22.1, 14]], columns=[MyEnum.TAYYAR, MyEnum.HAYDAR])
print(df[MyEnum.TAYYAR])

这是调用它时的输出。它按预期工作,一切正常。

> python mypy_enum.py
0    12.2
1     8.8
2    22.1
Name: MyEnum.TAYYAR, dtype: float64

当你用mypy调用它时;

> mypy mypy_enum.py  
mypy_enum.py:12: error: No overload variant of "__getitem__" of "DataFrame" matches argument type "MyEnum"  [call-overload]
mypy_enum.py:12: note: Possible overload variants:
mypy_enum.py:12: note:     def __getitem__(self, Union[str, bytes, date, datetime, timedelta, bool, int, float, complex, Timestamp, Timedelta], /) -> Series[Any]
mypy_enum.py:12: note:     def __getitem__(self, slice, /) -> DataFrame
mypy_enum.py:12: note:     def [ScalarT] __getitem__(self, Union[Tuple[Any, ...], Series[bool], DataFrame, List[str], List[ScalarT], Index, ndarray[Any, dtype[str_]], ndarray[Any, dtype[bool_]], Sequence[Tuple[Union[str, bytes, date, datetime, timedelta, bool, int, float, complex, Timestamp, Timedelta], ...]]], /) -> DataFrame
Found 1 error in 1 file (checked 1 source file)

难道__getitem__不应该支持列类型本身吗?如何解决这个问题?

js4nwp54

js4nwp541#

此问题是由于pandas-stubs中的此bug引起的。现在已在PR/596中修复。
在修复之前,__getitem__的第一个重载的类型签名是这样的:

@overload
    def __getitem__(self, idx: Scalar | tuple[Hashable, ...]) -> Series: ...

解决方案是将Hashable类型添加到__getitem__的重载中。
这里我列出了修复后__getitem__的所有3个重载。

@overload
    def __getitem__(self, idx: Scalar | Hashable) -> Series: ...
    @overload
    def __getitem__(self, rows: slice) -> DataFrame: ...
    @overload
    def __getitem__(
        self,
        idx: Series[_bool]
        | DataFrame
        | Index
        | np_ndarray_str
        | np_ndarray_bool
        | list[_ScalarOrTupleT],
    ) -> DataFrame: ...

问题中报告的示例现在满足第一个重载,因此类型检查器很高兴。

相关问题