使用模拟进行Python单元测试递归高阶函数

roejwanj  于 2023-01-18  发布在  Python
关注(0)|答案(1)|浏览(138)

我有以下函数来实现指数退避算法:

class Retry():
def exponential_backoff_retry(self, function, *args, n=1):
    MAX_TRIES = 8
    try:
        f = function(*args)
    except:
        if n > MAX_TRIES:
            return None
        n += 1
        time.sleep((2 ** n) + (random.randint(0, 1000) / 1000.0))
        return self.exponential_backoff_retry(function, *args, n)
    else:
        return f

我想做的是写一个单元测试来确认行为,给定一个传入的函数,例如一个发出API请求的函数,在异常的情况下它会重试很多次。
以下是我目前所做的尝试:

@mock.patch('requests.post')
@mock.patch('utils.retry.Retry.exponential_backoff_retry', side_effect=Exception('whoops'))
def test_exponential_backoff(self, mock_retry, req_post_mock):
    req_post_mock.return_value = {"status_code": 202}
    with self.assertRaises(Exception):
        mock_retry(req_post_mock)
    self.assertEqual(req_post_mock.return_value["status_code"], 202)
    self.assertEqual(mock_retry.call.count, 8)

任何建议都将不胜感激。

pengsaosao

pengsaosao1#

假设前

首先,提供的代码似乎有一个小错误(除了指示错误),您应该更改函数参数的顺序,如def exponential_backoff_retry(self, function,n=1, *args):而不是def exponential_backoff_retry(self, function, *args, n=1):,或者如果您更喜欢在递归函数调用return self.exponential_backoff_retry(function, *args, n=n)中将n作为命名参数

关于如何测试的主要答案

您可以创建一个模拟类来保存函数调用计数器

class MockFunctionHelperClass:
    def __init__(self, failed_attempts_before_success: int) -> None:
        self.failed_attempts_before_success = failed_attempts_before_success
        self.attempts_counter = 0

    def mock_function(self, return_value):
        self.attempts_counter += 1
        if self.attempts_counter <= self.failed_attempts_before_success:
            raise Exception
        return return_value

您可以在测试方法中使用helper类,如下所示:

def test_success_before_max_retries():
    failed_attempts_before_success = 2
    expected_result = True
    mock_function_helper_class = MockFunctionHelperClass(failed_attempts_before_success)
    result = Retry().exponential_backoff_retry(
    mock_function_helper_class.mock_function, 1, expected_result
)
    assert result == expected_result
    assert (
    mock_function_helper_class.attempts_counter
    == failed_attempts_before_success + 1
)

def test_failed_attempt_limit():
    expected_result = True
    mock_function_helper_class = MockFunctionHelperClass(3)
    result = Retry().exponential_backoff_retry(
    mock_function_helper_class.mock_function, 8, expected_result
    )
    assert result == None
    assert mock_function_helper_class.attempts_counter == 2

相关问题