使用yield对Python函数进行正确的类型注解

omjgkv6w  于 2023-03-28  发布在  Python
关注(0)|答案(3)|浏览(167)

阅读了Eli Bendersky的文章on implementing state machines via Python coroutines后,我想...

  • 请参见在Python3下运行的示例
  • 并为生成器添加适当的类型注解

我成功地完成了第一部分(* 但没有使用async def s或yield from s,我基本上只是移植了代码-所以任何改进都是最受欢迎的 *)。
但是我需要一些关于协程的类型注解的帮助:

#!/usr/bin/env python3

from typing import Callable, Generator

def unwrap_protocol(header: int=0x61,
                    footer: int=0x62,
                    dle: int=0xAB,
                    after_dle_func: Callable[[int], int]=lambda x: x,
                    target: Generator=None) -> Generator:
    """ Simplified protocol unwrapping co-routine."""
    #
    # Outer loop looking for a frame header
    #
    while True:
        byte = (yield)
        frame = []  # type: List[int]

        if byte == header:
            #
            # Capture the full frame
            #
            while True:
                byte = (yield)
                if byte == footer:
                    target.send(frame)
                    break
                elif byte == dle:
                    byte = (yield)
                    frame.append(after_dle_func(byte))
                else:
                    frame.append(byte)

def frame_receiver() -> Generator:
    """ A simple co-routine "sink" for receiving full frames."""
    while True:
        frame = (yield)
        print('Got frame:', ''.join('%02x' % x for x in frame))

bytestream = bytes(
    bytearray((0x70, 0x24,
               0x61, 0x99, 0xAF, 0xD1, 0x62,
               0x56, 0x62,
               0x61, 0xAB, 0xAB, 0x14, 0x62,
               0x7)))

frame_consumer = frame_receiver()
next(frame_consumer)  # Get to the yield

unwrapper = unwrap_protocol(target=frame_consumer)
next(unwrapper)  # Get to the yield

for byte in bytestream:
    unwrapper.send(byte)

它运行正常…

$ ./decoder.py 
Got frame: 99afd1
Got frame: ab14

...以及类型检查:

$ mypy --disallow-untyped-defs decoder.py 
$

但我很确定我可以做得比在类型规范中使用Generator基类更好(就像我对Callable所做的那样)。我知道它需要3个类型参数(Generator[A,B,C]),但我不确定它们在这里是如何指定的。
任何帮助最欢迎。

wtlkbnrh

wtlkbnrh1#

我自己找到了答案。
我搜索了一下,但是在official typing documentation for Python 3.5.2中没有找到Generator的3个类型参数的文档-除了一个真正神秘的提到...

class typing.Generator(Iterator[T_co], Generic[T_co, T_contra, V_co])

幸运的是,the original PEP484(这一切的开始)更有帮助:

  • “生成器函数的返回类型可以通过www.example.com模块提供的泛型类型Generator[yield_type,send_type,return_type]进行注解typing.py:*
def echo_round() -> Generator[int, float, str]:
    res = yield
    while res:
        res = yield round(res)
    return 'OK'

基于此,我能够注解我的生成器,并看到mypy确认了我的分配:

from typing import Callable, Generator

# A protocol decoder:
#
# - yields Nothing
# - expects ints to be `send` in his yield waits
# - and doesn't return anything.
ProtocolDecodingCoroutine = Generator[None, int, None]

# A frame consumer (passed as an argument to a protocol decoder):
#
# - yields Nothing
# - expects List[int] to be `send` in his waiting yields
# - and doesn't return anything.
FrameConsumerCoroutine = Generator[None, List[int], None]

def unwrap_protocol(header: int=0x61,
                    footer: int=0x62,
                    dle :int=0xAB,
                    after_dle_func: Callable[[int], int]=lambda x: x,
                    target: FrameConsumerCoroutine=None) -> ProtocolDecodingCoroutine:
    ...

def frame_receiver() -> FrameConsumerCoroutine:
    ...

我通过交换类型的顺序来测试我的作业-然后正如预期的那样,mypy抱怨并要求正确的类型(如上所示)。
完整的代码is accessible from here.
我将这个问题留到几天后再讨论,以防有人想插话--特别是在使用Python 3.5的新协程风格(async def等)方面--我希望能得到一些关于它们在这里如何使用的提示。

n6lpvg4x

n6lpvg4x2#

如果你有一个使用yield的简单函数,那么你可以使用Iterator类型来注解它的结果,而不是Generator

from collections.abc import Iterator  # Python >=3.9

def count_up() -> Iterator[int]:
    for x in range(10):
        yield x

在Python〈3.9中,必须以不同的方式导入Iterator

from typing import Iterator  # Python <3.9
lc8prwob

lc8prwob3#

在撰写本文时,Python documentation也明确提到了如何处理异步情况(在接受的答案中已经提到了非异步示例)。
从那里引用:

async def echo_round() -> AsyncGenerator[int, float]:
    sent = yield 0
    while sent >= 0.0:
        rounded = await round(sent)
        sent = yield rounded

(第一个参数是yield-type,第二个参数是send-type)或简单情况下(send-type为None)

async def infinite_stream(start: int) -> AsyncIterator[int]:
    while True:
        yield start
        start = await increment(start)

相关问题