如何使用unittest修补具有不同返回值的链式函数?

hkmswyz6  于 2021-07-12  发布在  Spark
关注(0)|答案(2)|浏览(362)

我有一个函数如下所示:

def my_function():
   sql_output = spark.sql('query').select('value').collect()[0]['value']

我尝试在unittest中使用mock和patch来修补变量 sql_output . 我正在修补 spark.sql 功能:

@patch("my_function.spark.sql")
def test_my_function(self, mock_sql_functions):
    from pyspark.sql.types import StringType
    from pyspark.sql.functions import lit

    mock_sql_functions.return_value.select.return_value.collect.return_value = None

我的目标是 sql_output 等于零。但是我不能这样做,因为返回值是none,但是 my_function 试图得到 [0]['value']None 价值观。
我尝试将返回值作为如下Dataframe:

sdf = spark.createDataFrame([('None', 'None', 'None')], ['value', 'value2', 'value3'])
sdf = sdf.withColumn("value", lit(None).cast(StringType()))

mock_sql_functions.return_value.select.return_value.collect.return_value = sdf

但它不起作用,因为我需要使用 [0]['value'] ,同时 collect() 我相信。所以我的问题是,如何设置这些倍数 return_value 不同的价值观?或者我该怎么设置 sql_output 值为 Noneunittest ?

webghufk

webghufk1#

使用当前代码解决此问题的最简单方法是使用以下内容:

import pyspark.sql

class SomethingTest(unittest.TestCase):

  @mock.patch.object(pyspark.sql, 'SparkSession')
  def test_my_function(self, mock_session):
    mock_session.sql.return_value.select.return_value.collect.return_value = [
        {'value': None},
    ]
    # This is the same value that thebadgateway's answer suggests.

    # the rest of your test

然而,在模拟测试中,侵入性较小通常更好。有简单的方法吗 DataFrame 您可以构造,用于 SparkSession.sql ? 这样,您还可以确保 .select() 以及 .collect() 子弹在做你想让他们做的事。
这看起来像:

@mock.patch.object(pyspark.sql, 'SparkSession')
  def test_my_function(self, mock_session):
    my_dataframe = pyspark.sql.DataFrame(...)  # build your frame
    mock_session.sql.return_value = my_dataframe

    # the rest of your test

虽然这并不是直接针对你的具体问题,但使用它通常是一个更好的主意 mock.patch.object ,因为这样可以直接引用要修补的对象,而不是依赖于以字符串的形式按名称进行搜索。

zxlwwiss

zxlwwiss2#

编辑:我明白现在的意思了。或许可以尝试使用以下测试类进行修补:

class TestSpark:
    def sql(self, arg): pass
    def select(self, arg): pass
    def collect(self): return [{"value": None}]

那么装潢师应该是 @mock.patch.object(the_module, "spark", return_value=TestSpark())

相关问题