PySpark mock: exception test passes, but the exception is not handled

Posted by qyuhtwio on 2023-01-21 in Apache

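I'm using Python 2.7 (don't ask me why; I'm a contractor and I just work with what they give me).
I'm trying to implement a PySpark function that uses the spark-bigquery-connector to submit a simple query through the Spark SQL data source API.
I'm seeing the strangest thing. I wrote the function and confirmed, by actually running it, that it works on the server. I wanted to make sure that if the user supplies a table name that doesn't exist, an exception is raised while the server's response for that table is processed. And I did that (I know this isn't TDD, but bear with me). Then I started writing tests for it, and obviously I had to generate a mock exception, which I did like this: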

module/query_bq

from py4j.protocol import Py4JJavaError
from pyspark.sql import SparkSession

def submit_bq_query(spark, table, filter_string):
    try:
        df = spark.read.format('bigquery').option('table', table).option('filter', filter_string).load()
        return df
    except Py4JJavaError as e:
        java_error_msg = str(e).split('\n')[1]
        if "java.lang.RuntimeException" in java_error_msg and ("{} not found".format(table)) in java_error_msg:
            raise Exception("RuntimeException: Table {} not found!".format(table))
        raise  # re-raise any other Java error instead of silently returning None

As I said, this works like a charm. The test for it looks like this:

module/test_query_bq

import pytest
from mock import patch, mock
from py4j.java_gateway import GatewayProperty, GatewayClient, JavaObject
from py4j.protocol import Py4JJavaError
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StructType

from module.query_bq import submit_bq_query

def mock_p4j_java_error_generator(msg):
    gateway_property = GatewayProperty(auto_field="Mock", pool="Mock")
    client = GatewayClient(gateway_property=gateway_property)
    java_object = JavaObject("RunTimeError", client)
    exception = Py4JJavaError(msg, java_exception=java_object)
    return Exception(exception)

def test_exception_is_thrown_if_table_not_present():

    # Given
    mock_table_name = 'spark_bq_test.false_table_name'
    mock_filter = "word is 'V'"
    mock_errmsg = "Table {} not found".format(mock_table_name)

    # Mocking
    mock_spark = mock.Mock()
    mock_spark_reader = mock.Mock()

    # Mocking return-values setup
    mock_spark.read.format.return_value = mock_spark_reader
    mock_spark_reader.option.return_value = mock_spark_reader
    mock_spark_reader.load.side_effect = mock_p4j_java_error_generator(mock_errmsg)

    # When
    with pytest.raises(Exception) as exception:
        submit_bq_query(mock_spark, mock_table_name, mock_filter)
    assert exception.value.message.errmsg == mock_errmsg

The test passes, but when I tried to debug it I noticed that the code after the exception is caught:

module/query_bq

...
    except Py4JJavaError as e:
        java_error_msg = str(e).split('\n')[1]  # This line is never reached!
        if "java.lang.RuntimeException" in java_error_msg and ("{} not found".format(table)) in java_error_msg:
            raise Exception("RuntimeException: Table {} not found!".format(table))
...

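is never reached. Yet the test still passes.
In short, the exception is mocked and raised in the test. It is also caught, but it is not handled: the assert passes and the test succeeds as if it had been handled, even though I never inspect the inside of the mocked exception. Again, module/query_bq works fine against the server: it returns DataFrames, and it handles the exception when the table doesn't exist. The problem is only in the test.
I need to do some additional work in the exception-handling part of module/query_bq, but I can't, because I don't understand what's going on. Can anyone explain?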

Answer 1, by 4nkexdtk

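After three days of struggling I sorted it out. The main problems were:

  • I was not mocking the spark.read call chain correctly, and
  • I was not instantiating the mock Py4JJavaError correctly.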
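The second point explains the symptom in the question. mock_p4j_java_error_generator returns Exception(exception), a plain Exception that merely wraps the Py4JJavaError, so `except Py4JJavaError` in submit_bq_query never matches. The wrapper escapes the function, pytest.raises(Exception) accepts any exception, and on Python 2.7 exception.value.message is just the wrapped Py4JJavaError, so the assertion compares the mock's own errmsg with itself. A minimal sketch that reproduces this without Spark or py4j follows; FakePy4JJavaError, submit and run are hypothetical stand-ins, and it uses Python 3's unittest.mock (the mock backport the question imports on 2.7 treats side_effect the same way).

```python
from unittest import mock


class FakePy4JJavaError(Exception):
    """Hypothetical stand-in for py4j.protocol.Py4JJavaError (no JVM needed)."""


def submit(reader):
    # Mirrors the question's except clause: only the specific error type is handled
    try:
        return reader.load()
    except FakePy4JJavaError:
        return "handled"


def run(side_effect):
    reader = mock.Mock()
    reader.load.side_effect = side_effect  # an exception instance is raised on call
    try:
        return submit(reader)
    except Exception as exc:
        # Anything the except clause did not match ends up here
        return "escaped: {}".format(type(exc).__name__)


# What the question's helper did: wrap the specific error in a generic Exception
wrapped = run(Exception(FakePy4JJavaError("Table not found")))
# What the test needs: raise the specific error type itself
direct = run(FakePy4JJavaError("Table not found"))

print(wrapped)  # escaped: Exception
print(direct)   # handled
```

So the fix on the test side is to make side_effect the Py4JJavaError instance itself rather than an Exception wrapping it.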

Here is how I did both:

.../utils/bigquery_util.py

import logging

from py4j.protocol import Py4JJavaError, Py4JNetworkError

def load_bq_table(spark, table, filter_string):
    tries = 3
    for i in range(tries):
        try:
            logging.info("SQL statement being executed...")
            df = get_df(spark, table, filter_string)
            logging.info("Table-ID: {}, Rows:{} x Cols:{}".format(table, df.count(), len(df.columns)))
            logging.debug("Table-ID: {}, Schema: {}".format(table, df.schema))
            return df
        except Py4JJavaError as e:
            java_exception_str = get_exception_msg(e)
            is_runtime_exception = "java.lang.RuntimeException" in java_exception_str
            table_not_found = ("{} not found".format(table)) in java_exception_str
            if is_runtime_exception and table_not_found:
                logging.error(java_exception_str)
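                raise RuntimeError("Table {} not found!".format(table))
            raise  # any other Java error: re-raise instead of retrying and returning None
        except Py4JNetworkError as ne:
            if i == tries - 1:  # compare ints with ==; "is" checks object identity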
                java_exception_str = ne.cause
                runtime_error_str = "Error while trying to reach server... {}"
                logging.error(java_exception_str)
                raise EnvironmentError(runtime_error_str.format(java_exception_str))
            continue

def get_exception_msg(e):
    return str(e.java_exception)

def get_df(spark, table, filter_string):
    return (spark.read
            .format('bigquery')
            .option('table', table)
            .option('filter', filter_string)
            .load())
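The Py4JNetworkError retry branch can be tested in the same style by giving the patched get_df a list as side_effect: mock raises each exception instance in the list in turn and returns any plain value. A self-contained sketch, assuming a simplified copy of the retry loop; load_with_retries and FakeNetworkError are hypothetical stand-ins, not the code above.

```python
from unittest import mock


class FakeNetworkError(Exception):
    """Hypothetical stand-in for py4j.protocol.Py4JNetworkError."""


def load_with_retries(get_df, tries=3):
    # Same shape as load_bq_table's loop: retry on network errors and
    # give up with EnvironmentError on the last attempt
    for i in range(tries):
        try:
            return get_df()
        except FakeNetworkError as ne:
            if i == tries - 1:
                raise EnvironmentError("Error while trying to reach server... {}".format(ne))


# Two failures, then success: the value from the third call is returned
flaky = mock.Mock(side_effect=[FakeNetworkError("down"), FakeNetworkError("down"), "df"])
result = load_with_retries(flaky)
calls_flaky = flaky.call_count

# Every call fails: the loop gives up after `tries` attempts
dead = mock.Mock(side_effect=FakeNetworkError("down"))
try:
    load_with_retries(dead)
    gave_up = False
except EnvironmentError:
    gave_up = True
```

Checking call_count alongside the result confirms that the loop actually retried rather than succeeding or failing on the first call.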

For the tests: .../test/utils/test_bigquery_util.py

import pytest
from mock import patch, mock

from <...>.utils.bigquery_util import load_bq_table
from <...>.test.utils.mock_py4jerror import *

def test_runtime_error_exception_is_thrown_if_table_not_present():

    # Given
    mock_table_name = 'spark_bq_test.false_table_name'
    mock_filter = "word is 'V'"

    # Mocking
    py4j_error_exception = get_mock_py4j_error_exception(get_mock_gateway_client(), mock_target_id="o123")
    mock_errmsg = "java.lang.RuntimeException: Table {} not found".format(mock_table_name)

    # When
    with mock.patch('red_agent.common.utils.bigquery_util.get_exception_msg', return_value=mock_errmsg):
        with mock.patch('red_agent.common.utils.bigquery_util.get_df', side_effect=py4j_error_exception):
            with pytest.raises(RuntimeError):
                mock_spark = mock.Mock()
                df = load_bq_table(mock_spark, mock_table_name, mock_filter)

...and finally, the mock Py4JJavaError: .../test/utils/mock_py4jerror.py

import mock
from py4j.protocol import Py4JJavaError, Py4JNetworkError

def get_mock_gateway_client():
    mock_client = mock.Mock()
    mock_client.send_command.return_value = "0"
    mock_client.converters = []
    mock_client.is_connected.return_value = True
    mock_client.deque = mock.Mock()
    return mock_client

def get_mock_java_object(mock_client, mock_target_id):
    mock_java_object = mock.Mock()
    mock_java_object._target_id = mock_target_id
    mock_java_object._gateway_client = mock_client
    return mock_java_object

def get_mock_py4j_error_exception(mock_client, mock_target_id):
    mock_java_object = get_mock_java_object(mock_client, mock_target_id)
    mock_errmsg = "An error occurred while calling {}.load.".format(mock_target_id)
    return Py4JJavaError(mock_errmsg, java_exception=mock_java_object)

def get_mock_py4j_network_exception(mock_target_id):
    mock_errmsg = "An error occurred while calling {}.load.".format(mock_target_id)
    return Py4JNetworkError(mock_errmsg)

Hope this helps someone...

Answer 2, by czq61nw1

When I tried something like create_autospec(spec=Py4JJavaError), I got the error TypeError: exceptions must derive from BaseException. The accepted answer helped me write this solution:

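from unittest import TestCase
from unittest.mock import MagicMock, patch

from py4j.protocol import Py4JJavaError


class MyException(Exception):
    """Raised when the Spark call fails with a Java RuntimeException."""


def problematic_code():
    """Placeholder for the Spark/py4j call that can raise Py4JJavaError."""


def my_function():
    try:
        problematic_code()
    except Py4JJavaError as e:
        if "java.lang.RuntimeException" in str(e.java_exception):
            raise MyException(f"The problematic code failed with {e}.") from e
        raise  # re-raise any other Java error unchanged


class MyFunctionTest(TestCase):  # all of the above lives in my_module
    def test(self):
        java_exception = MagicMock(_target_id="_")
        java_exception.__str__.return_value = "java.lang.RuntimeException"
        py4j_java_error = Py4JJavaError("_", java_exception=java_exception)
        # Patching __str__ means formatting {e} never touches a gateway client
        with patch.object(Py4JJavaError, "__str__", return_value="_"), \
                patch("my_module.problematic_code", side_effect=py4j_java_error):
            self.assertRaises(MyException, my_function)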

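I realized that the Java gateway client is only used in Py4JJavaError's __str__ method, so I simply mocked that whole method instead of mocking the client.
Reference: https://github.com/py4j/py4j/blob/master/py4j-python/src/py4j/protocol.py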
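Both observations in this answer, the TypeError from a spec'd mock and the fact that patching __str__ is enough, can be reproduced without a JVM. A minimal sketch, assuming a hypothetical FakeJavaError class whose __str__, like the real Py4JJavaError.__str__ per the reference above, asks the gateway client for the message; the class and the "describe" command are illustrative only.

```python
from unittest import mock


class FakeJavaError(Exception):
    """Hypothetical stand-in for Py4JJavaError: __str__ needs a gateway client."""

    def __init__(self, msg, java_exception):
        super().__init__(msg, java_exception)
        self.java_exception = java_exception

    def __str__(self):
        # Stands in for the real class asking the JVM for the message
        return self.java_exception._gateway_client.send_command("describe")


# 1. A spec'd mock passes isinstance checks but cannot be raised
try:
    raise mock.Mock(spec=FakeJavaError)
except TypeError as exc:
    spec_mock_error = str(exc)

# 2. A real instance can be raised; only the Java side is a MagicMock
java_exception = mock.MagicMock()
java_exception.__str__.return_value = "java.lang.RuntimeException: Table t not found"
err = FakeJavaError("_", java_exception=java_exception)

# Unpatched, __str__ returns whatever the mocked client gives back (not a str)
try:
    str(err)
    unpatched = "ok"
except TypeError:
    unpatched = "needs a real client"

# Patched, no client is involved at all
with mock.patch.object(FakeJavaError, "__str__", return_value="patched"):
    outer_text = str(err)
inner_text = str(err.java_exception)
```

The production code only inspects str(e.java_exception), which the MagicMock answers directly, so patching the outer __str__ is the only piece needed to keep the gateway client out of the test.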
