在Python Asyncio中限制异步函数

aiqt4smr  于 2023-03-13  发布在  Python
关注(0)|答案(3)|浏览(115)

我有一个 awaitable 列表,想把它交给 `asyncio.AbstractEventLoop` 运行,但我需要对发往第三方 API 的请求进行限速。
我希望避免先等待、再把 future 交给事件循环的做法,因为在等待期间我会阻塞事件循环。我有哪些选择?`Semaphore` 和 `ThreadPool` 会限制同时运行的数量,但这不是我的问题。我需要把请求限制在每秒 100 个,至于每个请求需要多长时间才能完成则无关紧要。
下面是一个只使用标准库、非常简洁的(不能完全正常工作的)示例,用来演示这个问题。它本应限速到每秒 100 个,但实际达到了每秒 116.651 个。在 asyncio 中,对异步请求的调度进行限速的最佳方式是什么?
工作代码:

import asyncio
from threading import Lock

class PTBNL:
    """Toy request dispatcher that paces request start times with a token bucket.

    Every request gets an integer id and a Future; the stand-in worker
    ``req_data`` resolves that Future through ``end_req``.
    """

    def __init__(self):
        self._req_id_seq = 0  # next request id to hand out
        self._futures = {}  # request id -> Future awaiting its result
        self._results = {}  # request id -> result collected for that request
        self.token_bucket = TokenBucket()
        self.token_bucket.set_rate(100)

    def run(self, *awaitables):
        """Drive the event loop synchronously.

        With no awaitables the loop runs forever; with one, its result is
        returned; with several, they are gathered and the list of their
        results is returned.
        """
        loop = asyncio.get_event_loop()

        if not awaitables:
            loop.run_forever()
        elif len(awaitables) == 1:
            return loop.run_until_complete(awaitables[0])
        else:
            future = asyncio.gather(*awaitables)
            return loop.run_until_complete(future)

    def sleep(self, secs) -> bool:
        """Run the loop for ``secs`` seconds; always returns True."""
        self.run(asyncio.sleep(secs))
        return True

    def get_req_id(self) -> int:
        """Return a fresh request id (0, 1, 2, ...)."""
        new_id = self._req_id_seq
        self._req_id_seq += 1
        return new_id

    def start_req(self, key):
        """Create, register and return the Future that ``end_req(key)`` resolves."""
        loop = asyncio.get_event_loop()
        future = loop.create_future()
        self._futures[key] = future
        return future

    def end_req(self, key, result=None):
        """Resolve the Future registered under ``key``.

        When ``result`` is None the value stored in ``_results`` for ``key``
        (default: an empty list) is used instead. Unknown keys are ignored.
        """
        future = self._futures.pop(key, None)
        if future is not None:
            if result is None:
                result = self._results.pop(key, [])
            if not future.done():
                future.set_result(result)

    def req_data(self, req_id, obj):
        """Do the work for one request (placeholder) and mark it finished."""
        # Do Some Work Here
        self.req_data_end(req_id)

    def req_data_end(self, req_id):
        """Log completion of ``req_id`` and resolve its Future."""
        print(req_id, " has ended")
        self.end_req(req_id)

    async def req_data_async(self, obj):
        """Issue a single, unthrottled request for ``obj`` and return its result."""
        req_id = self.get_req_id()
        future = self.start_req(req_id)

        self.req_data(req_id, obj)

        return await future

    async def req_data_batch_async(self, contracts):
        """Schedule one request per contract, paced by ``self.token_bucket``.

        Returns ``(futures, rate)``. ``futures`` holds the per-request Futures
        in input order. ``rate`` is the number of requests divided by the
        elapsed event-loop time, or 0.0 when no time elapsed (e.g. an empty
        batch, which previously raised UnboundLocalError).
        """
        loop = asyncio.get_event_loop()
        futures = []
        # Start the clock before the loop so ``start`` always exists.
        start = loop.time()

        for contract in contracts:
            req_id = self.get_req_id()
            futures.append(self.start_req(req_id))
            # consume() returns how many seconds to wait before this request may go.
            nap = self.token_bucket.consume(1)
            loop.call_later(nap, self.req_data, req_id, contract)

        await asyncio.gather(*futures)
        elapsed = loop.time() - start
        rate = len(futures) / elapsed if elapsed > 0 else 0.0
        return futures, rate

class TokenBucket:
    """Token bucket that tells callers how long to wait instead of blocking.

    ``consume`` debits the bucket right away (the balance may go negative)
    and returns the delay, in event-loop seconds, until the debt has been
    refilled. The bucket holds at most ``rate`` tokens, i.e. one time unit's
    worth. A rate of 0 disables limiting.
    """

    def __init__(self):
        self.lock = Lock()
        self.last = asyncio.get_event_loop().time()
        self.rate = 0
        self.tokens = 0

    def set_rate(self, rate):
        """Set the refill rate (tokens per second) and fill the bucket to capacity."""
        with self.lock:
            self.rate = rate
            self.tokens = rate

    def consume(self, tokens):
        """Take ``tokens`` from the bucket and return the seconds to wait (0 if none)."""
        with self.lock:
            if not self.rate:
                return 0

            current = asyncio.get_event_loop().time()
            refill = (current - self.last) * self.rate
            self.last = current

            # Refill, but never beyond capacity (== rate).
            self.tokens = min(self.tokens + refill, self.rate)
            self.tokens -= tokens

            return 0 if self.tokens >= 0 else -self.tokens / self.rate

if __name__ == '__main__':
    # Turn on asyncio's debug mode for extra runtime diagnostics.
    asyncio.get_event_loop().set_debug(True)
    app = PTBNL()

    requests = list(range(500))
    futures, rate = app.run(app.req_data_batch_async(requests))

    print(futures)
    print(rate)

编辑:我在这里添加了一个使用信号量的 ThrottleTestApp 简单示例,但它仍然无法对执行进行限速:

import asyncio
import time

class ThrottleTestApp:
    """Toy request dispatcher that rate-limits requests with a semaphore.

    Each request must take a permit from ``self.sem`` before it runs. The
    requests never give permits back. Instead, the background coroutine
    ``allow_requests`` adds one permit every ``period / rate`` seconds, which
    caps throughput at about ``rate`` requests per ``period``.
    """

    def __init__(self, rate: int = 100, period: float = 1.0):
        if rate < 1 or period <= 0:
            raise ValueError("rate must be >= 1 and period must be > 0")
        self._req_id_seq = 0  # next request id to hand out
        self._futures = {}  # request id -> Future awaiting its result
        self._results = {}  # request id -> result collected for that request
        self._rate = rate
        self._period = period
        # Bounded, so idle time can never bank more than ``rate`` permits.
        self.sem = asyncio.BoundedSemaphore(rate)

    async def allow_requests(self, sem):
        """Release one permit into ``sem`` every ``period / rate`` seconds, forever.

        Start it with ``loop.create_task(self.allow_requests(self.sem))`` and
        cancel the returned task to stop it. ``req_data_batch_async`` does
        both for you.

        ``sem`` should be an ``asyncio.BoundedSemaphore``. A full bounded
        semaphore raises ValueError on release, which is how surplus permits
        are dropped without reading the private ``_value`` attribute. On
        newer Pythons, topping up via ``_value`` woke every waiter at once.
        """
        interval = self._period / self._rate
        while True:
            await asyncio.sleep(interval)
            try:
                sem.release()
            except ValueError:
                pass  # already holding the maximum number of permits

    async def do_request(self, req_id, obj):
        """Wait for a permit from ``self.sem``, then do the request's work."""
        await self.sem.acquire()
        # The permit is deliberately not released here: allow_requests
        # replenishes permits at the target rate.
        self.req_data(req_id, obj)

    def run(self, *awaitables):
        """Drive the event loop synchronously.

        With no awaitables the loop runs forever; with one, its result is
        returned; with several, the list of their gathered results.
        """
        loop = asyncio.get_event_loop()

        if not awaitables:
            loop.run_forever()
        elif len(awaitables) == 1:
            return loop.run_until_complete(awaitables[0])
        else:
            future = asyncio.gather(*awaitables)
            return loop.run_until_complete(future)

    def sleep(self, secs: float = 0.02) -> bool:
        """Run the loop for ``secs`` seconds; always returns True."""
        self.run(asyncio.sleep(secs))
        return True

    def get_req_id(self) -> int:
        """Return a fresh request id (0, 1, 2, ...)."""
        new_id = self._req_id_seq
        self._req_id_seq += 1
        return new_id

    def start_req(self, key):
        """Create, register and return the Future that ``end_req(key)`` resolves."""
        loop = asyncio.get_event_loop()
        future = loop.create_future()
        self._futures[key] = future
        return future

    def end_req(self, key, result=None):
        """Resolve the Future registered under ``key``.

        When ``result`` is None the value stored in ``_results`` for ``key``
        (default: an empty list) is used instead. Unknown keys are ignored.
        """
        future = self._futures.pop(key, None)
        if future is not None:
            if result is None:
                result = self._results.pop(key, [])
            if not future.done():
                future.set_result(result)

    def req_data(self, req_id, obj):
        """Do the work for one request (placeholder) and mark it finished."""
        # This is the method that "does" something
        self.req_data_end(req_id)

    def req_data_end(self, req_id):
        """Log completion of ``req_id`` and resolve its Future."""
        print(req_id, " has ended")
        self.end_req(req_id)

    async def req_data_batch_async(self, objs):
        """Run one throttled request per item of ``objs`` and wait for all of them.

        Prints the achieved throughput and returns the per-request Futures in
        input order.
        """
        loop = asyncio.get_event_loop()
        futures = []
        tasks = []
        # Without the refill task, no permits beyond the initial ones are ever issued.
        limiter = loop.create_task(self.allow_requests(self.sem))
        start = time.time()
        try:
            for obj in objs:
                req_id = self.get_req_id()
                futures.append(self.start_req(req_id))
                # Calling a coroutine function only builds a coroutine object.
                # It has to be scheduled as a task or it never runs (and the
                # gather below would wait forever).
                tasks.append(loop.create_task(self.do_request(req_id, obj)))
            # Awaiting the tasks surfaces any exception raised by a request.
            await asyncio.gather(*tasks)
            await asyncio.gather(*futures)
        finally:
            # Stop the refill loop and any request still pending (after an
            # error or cancellation) so no task outlives this call.
            for task in (limiter, *tasks):
                task.cancel()
            await asyncio.gather(limiter, *tasks, return_exceptions=True)

        elapsed = time.time() - start
        if elapsed > 0:
            print("Roughly %s per second" % (len(futures) / elapsed))

        return futures

if __name__ == '__main__':
    # Turn on asyncio's debug mode for extra runtime diagnostics.
    asyncio.get_event_loop().set_debug(True)
    app = ThrottleTestApp()
    app.run(app.req_data_batch_async(list(range(10000))))
pdsfdshx

pdsfdshx1#

您可以通过实现漏桶算法(leaky bucket algorithm)来做到这一点:

import asyncio
import contextlib
import collections
import time

from types import TracebackType
from typing import Dict, Optional, Type

try:  # Python 3.7+
    base = contextlib.AbstractAsyncContextManager
    _current_task = asyncio.current_task
except AttributeError:  # older Pythons lack both names
    base = object  # type: ignore
    _current_task = asyncio.Task.current_task  # type: ignore


class AsyncLeakyBucket(base):
    """A leaky bucket rate limiter.

    The bucket holds at most ``max_rate`` units, so up to ``max_rate``
    acquisitions can happen back to back before ``acquire`` starts blocking.
    Capacity drains at ``max_rate / time_period`` units per second.

    time_period is measured in seconds; the default is 60.
    """

    def __init__(
        self,
        max_rate: float,
        time_period: float = 60,
        loop: Optional[asyncio.AbstractEventLoop] = None
    ) -> None:
        self._loop = loop
        self._max_level = max_rate
        self._rate_per_sec = max_rate / time_period
        self._level = 0.0
        self._last_check = 0.0
        # waiting task -> future used to wake it early when capacity may be free
        self._waiters: Dict[asyncio.Task, asyncio.Future] = collections.OrderedDict()

    def _leak(self) -> None:
        """Drip out capacity for the time elapsed since the last check."""
        # monotonic() is unaffected by wall-clock adjustments, unlike
        # time.time(), so the computed drain cannot go negative or jump.
        now = time.monotonic()
        if self._level:
            elapsed = now - self._last_check
            decrement = elapsed * self._rate_per_sec
            self._level = max(self._level - decrement, 0)
        self._last_check = now

    def has_capacity(self, amount: float = 1) -> bool:
        """Return True if ``amount`` more units fit in the bucket right now.

        Also drains the bucket for the elapsed time. If this amount would
        leave spare room, it signals the earliest registered waiter so that
        waiter re-checks early. The waiter only wakes once the current task
        yields with an await.
        """
        self._leak()
        requested = self._level + amount
        if requested < self._max_level:
            for fut in self._waiters.values():
                if not fut.done():
                    fut.set_result(True)
                    break
        return self._level + amount <= self._max_level

    async def acquire(self, amount: float = 1) -> None:
        """Acquire ``amount`` units of space in the bucket.

        If the bucket is full, block until enough capacity has drained out.

        Raises:
            ValueError: if ``amount`` exceeds the bucket capacity (``max_rate``).
        """
        if amount > self._max_level:
            raise ValueError("Can't acquire more than the bucket capacity")

        loop = self._loop or asyncio.get_event_loop()
        task = _current_task(loop)
        assert task is not None
        try:
            while not self.has_capacity(amount):
                # Wait for the next drip, or for an early wake-up from
                # has_capacity() if capacity comes up sooner.
                fut = loop.create_future()
                self._waiters[task] = fut
                try:
                    # shield() stops the timeout from cancelling fut itself.
                    # Do not pass loop= here: wait_for() lost that parameter
                    # in Python 3.10 and raises TypeError if it is given.
                    await asyncio.wait_for(
                        asyncio.shield(fut),
                        1 / self._rate_per_sec * amount,
                    )
                except asyncio.TimeoutError:
                    pass
                finally:
                    fut.cancel()
        finally:
            # Deregister even when the wait is cancelled. Otherwise a dead
            # entry would absorb the wake-up meant for a live waiter.
            self._waiters.pop(task, None)

        self._level += amount

        return None

    async def __aenter__(self) -> None:
        await self.acquire()
        return None

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType]
    ) -> None:
        # Capacity is not returned on exit; it only drains over time.
        return None

注意,我们是在需要时顺带从桶中漏出容量的,因此不需要运行单独的异步任务来降低水位;而是在检查剩余容量是否足够时才漏出容量。
请注意,等待容量的任务保存在有序字典中,当可能再次出现空闲容量时,第一个仍在等待的任务会被提前唤醒。
您可以把它用作上下文管理器;当桶已满时,尝试获取会一直阻塞,直到再次释放出足够的容量:

# 100 acquisitions per time_period (which defaults to 60 seconds)
bucket = AsyncLeakyBucket(100)

# ...

async with bucket:
    # only reached once the bucket is no longer full

也可以直接调用 `acquire()`:

await bucket.acquire()  # blocks until there is enough capacity in the bucket

或者,您也可以事先简单地检查一下桶中是否还有剩余容量:

if not bucket.has_capacity():
    # reject the request: the bucket is full, so the rate limit has been reached
    ...

请注意,您可以通过增加或减少您“滴入”桶中的量来将某些请求计数为“较重”或“较轻”:

# Count this request as 10 units against the bucket's capacity
await bucket.acquire(10)
# Check whether half a unit of capacity is currently free
if bucket.has_capacity(0.5):

不过要小心这个;当混合大液滴和小液滴时,在最大速率或接近最大速率时,小液滴倾向于在大液滴之前流动,这是因为在存在用于较大液滴的空间之前存在用于较小液滴的足够的空闲容量的可能性较大。
演示:

>>> import asyncio, time
>>> bucket = AsyncLeakyBucket(5, 10)
>>> async def task(id):
...     await asyncio.sleep(id * 0.01)
...     async with bucket:
...         print(f'{id:>2d}: Drip! {time.time() - ref:>5.2f}')
...
>>> ref = time.time()
>>> tasks = [task(i) for i in range(15)]
>>> result = asyncio.run(asyncio.wait(tasks))
 0: Drip!  0.00
 1: Drip!  0.02
 2: Drip!  0.02
 3: Drip!  0.03
 4: Drip!  0.04
 5: Drip!  2.05
 6: Drip!  4.06
 7: Drip!  6.06
 8: Drip!  8.06
 9: Drip! 10.07
10: Drip! 12.07
11: Drip! 14.08
12: Drip! 16.08
13: Drip! 18.08
14: Drip! 20.09

桶在开始时被快速地突发填充,使得剩余的任务被更均匀地分布;每2秒就释放足够的容量用于处理另一任务。
最大突发大小等于最大速率值,在上述演示中设置为 5。如果不想允许突发,请将最大速率设置为 1,并将时间段设置为两次滴落之间的最短间隔:

>>> bucket = AsyncLeakyBucket(1, 1.5)  # no bursts, drip every 1.5 seconds
>>> async def task():
...     async with bucket:
...         print(f'Drip! {time.time() - ref:>5.2f}')
...
>>> ref = time.time()
>>> tasks = [task() for _ in range(5)]
>>> result = asyncio.run(asyncio.wait(tasks))
Drip!  0.00
Drip!  1.50
Drip!  3.01
Drip!  4.51
Drip!  6.02

我抽时间把它打包成一个Python项目:https://github.com/mjpieters/aiolimiter

fafcakar

fafcakar2#

另一个解决方案--使用有界信号量--由同事、导师和朋友提供,如下所示:

import asyncio

class AsyncLeakyBucket(object):

    def __init__(self, max_tasks: float, time_period: float = 60, loop: asyncio.events=None):
        self._delay_time = time_period / max_tasks
        self._sem = asyncio.BoundedSemaphore(max_tasks)
        self._loop = loop or asyncio.get_event_loop()
        self._loop.create_task(self._leak_sem())

    async def _leak_sem(self):
        """
        Background task that leaks semaphore releases based on the desired rate of tasks per time_period
        """
        while True:
            await asyncio.sleep(self._delay_time)
            try:
                self._sem.release()
            except ValueError:
                pass

    async def __aenter__(self) -> None:
        await self._sem.acquire()

    async def __aexit__(self, exc_type, exc, tb) -> None:
        pass

它同样可以配合 @Martijn 答案中相同的 `async with bucket` 代码使用。

avkwfej4

avkwfej43#

一个简单的解决方案,可同时控制每秒最大请求数和与 API 的最大并发连接数;我把它用在了盈透证券(Interactive Brokers)API 上。

import asyncio
import datetime as dt
import random

async def send_request(num):
    """Simulate one API call: log the request, wait 0.1 or 0.2 s, log the response."""
    print(f"Request  {num:>2} at {dt.datetime.now()}")
    latency = random.choice([0.1, 0.2])
    await asyncio.sleep(latency)
    print(f"Response {num:>2} at {dt.datetime.now()}")

def requests_per_second(request_datetimes):
    """Estimate the current request rate from the most recent request time.

    The estimate is ``1 / seconds since the last request``. It is 0 when no
    request has been made yet, and infinity when the last request carries
    the current timestamp; that case used to raise ZeroDivisionError.
    """
    if not request_datetimes:
        return 0
    elapsed = (dt.datetime.now() - request_datetimes[-1]).total_seconds()
    if elapsed == 0:
        # Same clock tick as the last request: treat it as over any finite limit.
        return float("inf")
    return 1 / elapsed


async def rate_limited_gather(*args, rate_limit=50, max_connections=10):
    """Manage max requests per second and max open connections for an API.

    Each awaitable in ``args`` is started only when the estimate from
    ``requests_per_second`` is at most ``rate_limit`` and fewer than
    ``max_connections`` previously started ones are still running.

    Returns the results in argument order. Because ``return_exceptions=True``
    is used, an exception raised by an awaitable is returned in place of its
    result.
    """
    tasks = []
    request_datetimes = []
    loop = asyncio.get_event_loop()
    connections = 0
    for arg in args:
        while (
            requests_per_second(request_datetimes) > rate_limit
            or connections >= max_connections
        ):
            await asyncio.sleep(1 / rate_limit)
            connections = sum(not t.done() for t in tasks)
        print(f"Requests per second: {requests_per_second(request_datetimes)}")
        request_datetimes.append(dt.datetime.now())
        tasks.append(loop.create_task(arg))
        connections = sum(not t.done() for t in tasks)
    return await asyncio.gather(*tasks, return_exceptions=True)

if __name__ == "__main__":
    # asyncio.run() creates, runs and closes its own event loop. Relying on
    # asyncio.get_event_loop() to create one when none is running is
    # deprecated and fails on recent Python versions.
    asyncio.run(rate_limited_gather(*[send_request(x) for x in range(10)]))

输出示例:

Requests per second: 0
Request   0 at 2023-03-11 10:34:49.348671
Requests per second: 49.696849219759464
Request   1 at 2023-03-11 10:34:49.368800
Requests per second: 49.69931911932807
Request   2 at 2023-03-11 10:34:49.388930
Requests per second: 49.69931911932807
Request   3 at 2023-03-11 10:34:49.409057
Requests per second: 49.72403162448411
Request   4 at 2023-03-11 10:34:49.429170
Response  0 at 2023-03-11 10:34:49.449260
Requests per second: 49.691910157026435
Request   5 at 2023-03-11 10:34:49.449298
Response  1 at 2023-03-11 10:34:49.469389
Requests per second: 49.68450340338848
Request   6 at 2023-03-11 10:34:49.469436
Response  2 at 2023-03-11 10:34:49.489529
Requests per second: 49.67956679417755
Request   7 at 2023-03-11 10:34:49.489566
Requests per second: 49.73392350922564
Request   8 at 2023-03-11 10:34:49.509682
Requests per second: 49.7116723006562
Request   9 at 2023-03-11 10:34:49.529858
Response  6 at 2023-03-11 10:34:49.569973
Response  7 at 2023-03-11 10:34:49.590072
Response  3 at 2023-03-11 10:34:49.609170
Response  4 at 2023-03-11 10:34:49.629267
Response  5 at 2023-03-11 10:34:49.650361
Response  8 at 2023-03-11 10:34:49.710456
Response  9 at 2023-03-11 10:34:49.730560

相关问题