How do I create a Python class from a PySpark DataFrame schema?

svmlkihl asked on 2023-10-15 in Spark

I have a relatively simple PySpark application that reads an input table, performs some transformations, and outputs a new table.
I would like to unit test this script against a dummy version of the input table containing known mock data. However, because the input table has a fairly complex nested schema, creating the input DataFrame by hand (e.g. directly in the test code, or from a JSON file) is quite cumbersome.
Typically only a limited number of fields are relevant to any given test case, so I want a solution that lets me easily create a DataFrame with the relevant fields set to known values and everything else left at defaults I don't have to specify.
One possible solution I've thought of is converting the DataFrame schema into a Python class. That would let me create an instance of the class for each test case, manipulate the relevant fields with plain Python attribute syntax, and then write a DataFrame built from the test-case instances to the dummy table. In code it would look something like this:

test_case1 = DataFrameClassRepresentation() # this class was generated from schema
test_case2 = DataFrameClassRepresentation()
test_case1.foo.bar = "some value"
test_case2.foo.baz = "some other value"

df = spark.createDataFrame([test_case1, test_case2])
df.writeTo("catalog.db.dummy_table").append()  # writeTo returns a DataFrameWriterV2; append()/create() performs the write

Is there built-in functionality in PySpark, or some other simple way, to generate such a class from a schema? Or is there a simpler strategy for achieving what I want that I haven't thought of?


w46czmvw · Answer 1

I'm not sure this is exactly what you're after, since it requires defining several classes to match the expected schema (initializing the class from the schema itself would be more convenient, but I'm not sure that's possible). It does, however, let you define default values and create test DataFrames with the proper schema. I've also defined a getSchema function that extracts the PySpark schema from an instance of DataFrameClassRepresentation.

from dataclasses import dataclass, asdict, field

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, BooleanType, FloatType, StringType, StructType, StructField

spark = SparkSession.builder.appName('test').getOrCreate()

@dataclass
class ColASchema:
    col1: int = 123
    col2: str = "xyz"

@dataclass
class ColBSchema:
    col3: int = 456
    col4: str = "abc"

@dataclass
class ColCSchema:
    col1: float = 3.14
    col2: str = "def"

@dataclass
class ColDSchema:
    # Nested dataclass defaults must use default_factory;
    # mutable instance defaults are rejected on Python 3.11+.
    col3: ColCSchema = field(default_factory=ColCSchema)
    col4: float = 1.62

@dataclass
class DataFrameClassRepresentation:
    colA: ColASchema = field(default_factory=ColASchema)
    colB: ColBSchema = field(default_factory=ColBSchema)
    colC: ColCSchema = field(default_factory=ColCSchema)
    colD: ColDSchema = field(default_factory=ColDSchema)

# Maps Python type names to the corresponding PySpark types.
python_to_pyspark_type_map = {
    'bool': BooleanType(),
    'int': IntegerType(),
    'float': FloatType(),
    'str': StringType(),
}

def getSchema(d):
    """Recursively build a PySpark StructType from a (possibly nested) dict."""
    if isinstance(d, dict):
        return StructType(
            [StructField(k, python_to_pyspark_type_map[type(v).__name__], True) if not isinstance(v, dict)
             else StructField(k, getSchema(v), True)
             for k, v in d.items()]
        )
    else:
        return python_to_pyspark_type_map[type(d).__name__]

You can then create the test DataFrame as follows:

test_case1 = DataFrameClassRepresentation()
test_case2 = DataFrameClassRepresentation(ColASchema(1,"abc"))

df = spark.createDataFrame([
    asdict(test_case1), 
    asdict(test_case2)
], getSchema(asdict(test_case1)))
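
The original answer doesn't include the display calls, but the output below matches show(truncate=False) followed by printSchema():

df.show(truncate=False)
df.printSchema()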

+----------+----------+-----------+-------------------+
|colA      |colB      |colC       |colD               |
+----------+----------+-----------+-------------------+
|{123, xyz}|{456, abc}|{3.14, def}|{{3.14, def}, 1.62}| # default values
|{1, abc}  |{456, abc}|{3.14, def}|{{3.14, def}, 1.62}|
+----------+----------+-----------+-------------------+

root
 |-- colA: struct (nullable = true)
 |    |-- col1: integer (nullable = true)
 |    |-- col2: string (nullable = true)
 |-- colB: struct (nullable = true)
 |    |-- col3: integer (nullable = true)
 |    |-- col4: string (nullable = true)
 |-- colC: struct (nullable = true)
 |    |-- col1: float (nullable = true)
 |    |-- col2: string (nullable = true)
 |-- colD: struct (nullable = true)
 |    |-- col3: struct (nullable = true)
 |    |    |-- col1: float (nullable = true)
 |    |    |-- col2: string (nullable = true)
 |    |-- col4: float (nullable = true)
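
To tie this back to the question, the resulting frame can then be written to the dummy table. Note that writeTo assumes a catalog supporting the DataFrameWriterV2 API (e.g. Iceberg); createOrReplace() is one of its write methods:

df.writeTo("catalog.db.dummy_table").createOrReplace()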

pkwftd7m · Answer 2

You can create a DataFrameClassRepresentation class from scratch that populates its attributes from the desired PySpark schema. The class contains a recursive method that unpacks the schema, no matter how deeply nested, into a Python dict with a default value for every field (including nested ones), and then sets instance attributes from that dict. There are also class-level per-type defaults; these aren't strictly necessary, but they may be helpful for your use case.

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, BooleanType, FloatType, StringType, StructType, StructField

class DataFrameClassRepresentation:
    # Default value to use for each leaf PySpark type.
    _pyspark_default_by_type = {
        BooleanType(): True,
        IntegerType(): 0,
        FloatType(): 3.14,
        StringType(): 'abc',
    }
    # PySpark type -> Python type name (used to key defaults_by_type).
    _pyspark_to_python_type = {
        BooleanType(): 'bool',
        IntegerType(): 'int',
        FloatType(): 'float',
        StringType(): 'str',
    }

    def _generate_sample_data(self, pyspark_schema):
        """Recursively unpack a schema into a dict with a default for every field."""
        if isinstance(pyspark_schema, StructType):
            return {
                field.name: self._pyspark_default_by_type[field.dataType]
                if not isinstance(field.dataType, StructType)
                else self._generate_sample_data(field.dataType)
                for field in pyspark_schema
            }
        else:
            return {
                self._pyspark_to_python_type[pyspark_schema]:
                self._pyspark_default_by_type[pyspark_schema]
            }

    def __init__(self, pyspark_schema, defaults_by_type=None):
        # defaults_by_type maps Python type names to default values; as written,
        # it overrides the built-in defaults for top-level scalar fields only.
        if defaults_by_type is None:
            defaults_by_type = {
                self._pyspark_to_python_type[t]: v for t, v
                in self._pyspark_default_by_type.items()
            }
        schema = self._generate_sample_data(pyspark_schema)
        for k, v in schema.items():
            if not isinstance(v, dict):
                setattr(self, k, defaults_by_type[type(v).__name__])
            else:
                setattr(self, k, v)

    def get_data(self) -> dict:
        """Return the data structure of a class instance as a plain dict."""
        return {key: value for key, value in self.__dict__.items()}

You can then declare instances of the class for your unit tests and set individual fields as needed:

spark = SparkSession.builder.appName('test').getOrCreate()

sample_schema = StructType([
    StructField('colA', StructType([
        StructField('col1', FloatType(), True),
        StructField('col2', StringType(), True),
        StructField('col3', IntegerType(), True),
        StructField('colA1', StructType([
            StructField('col4', StringType(), True),
            StructField('col5', FloatType(), True),
        ])),
    ])),
    StructField('colB', IntegerType(), True),
    StructField('colC', BooleanType(), True),
])

test_case1 = DataFrameClassRepresentation(sample_schema)
test_case2 = DataFrameClassRepresentation(sample_schema)

test_case1.colA['col1'] = 10.00 # modify a nested field

test_df = spark.createDataFrame([
    test_case1.get_data(), 
    test_case2.get_data()
], sample_schema)

Here is the printout of test_df and its schema:
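As above, the calls aren't shown in the original answer, but the output matches show(truncate=False) followed by printSchema():

test_df.show(truncate=False)
test_df.printSchema()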

+---------------------------+----+----+
|colA                       |colB|colC|
+---------------------------+----+----+
|{10.0, abc, 0, {abc, 3.14}}|0   |true|
|{3.14, abc, 0, {abc, 3.14}}|0   |true|
+---------------------------+----+----+

root
 |-- colA: struct (nullable = true)
 |    |-- col1: float (nullable = true)
 |    |-- col2: string (nullable = true)
 |    |-- col3: integer (nullable = true)
 |    |-- colA1: struct (nullable = true)
 |    |    |-- col4: string (nullable = true)
 |    |    |-- col5: float (nullable = true)
 |-- colB: integer (nullable = true)
 |-- colC: boolean (nullable = true)
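
As a side note, the constructor's defaults_by_type parameter lets you supply your own per-type defaults, keyed by Python type name. As written, the override only applies to top-level scalar fields; nested structs keep the built-in defaults. A minimal sketch:

test_case3 = DataFrameClassRepresentation(
    sample_schema,
    defaults_by_type={'bool': False, 'int': -1, 'float': 0.0, 'str': ''},
)
# colB becomes -1 and colC becomes False; the nested fields of
# colA keep the built-in defaults (3.14, 'abc', ...).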
