如何在PySpark上并行计算同一 Dataframe 上的不同聚合?

ljsrvy3e  于 2022-11-01  发布在  Spark
关注(0)|答案(1)|浏览(218)

我想在PySpark上手工计算一个大型 Dataframe 的一些自定义汇总统计信息。为了简单起见,让我使用一个更简单的虚拟数据集,如下所示:

from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import DataType, NumericType, DateType, TimestampType
import pyspark.sql.types as t
import pyspark.sql.functions as f
from datetime import datetime

spark = (
    SparkSession.builder
    .appName("pyspark")
    .master("local[*]")
    .getOrCreate()
)
dd = [
    ("Alice", 18.0, datetime(2022, 1, 1)),
    ("Bob", None, datetime(2022, 2, 1)),
    ("Mark", 33.0, None),
    (None, 80.0, datetime(2022, 4, 1)),
]
schema = t.StructType(
    [
        t.StructField("T", t.StringType()),
        t.StructField("C", t.DoubleType()),
        t.StructField("D", t.DateType()),
    ]
)
df = spark.createDataFrame(dd, schema)

好的,问题是,我想计算一些聚合:missing countsstddevmaxmin,当然,我希望并行执行。我可以采用两种方法:

方法1:一个选择查询

这样,我让Spark引擎通过生成一个大的选择查询来进行并行计算。

def df_dtypes(df: DataFrame) -> List[Tuple[str, DataType]]:
    """
    Like df.dtypes attribute of Spark DataFrame, but returning DataType objects instead
    of strings.
    """
    return [(str(f.name), f.dataType) for f in df.schema.fields]

def get_missing(df: DataFrame) -> Tuple:
    suffix = "__missing"
    result = (
        *(
            (
                f.count(
                    f.when(
                        (f.isnan(c) | f.isnull(c)),
                        c,
                    )
                )
                / f.count("*")
                * 100
                if isinstance(t, NumericType)  # isnan only works for numeric types
                else f.count(
                    f.when(
                        f.isnull(c),
                        c,
                    )
                )
                / f.count("*")
                * 100
            )
            .cast("double")
            .alias(c + suffix)
            for c, t in df_dtypes(df)
        ),
    )

    return result

def get_min(df: DataFrame) -> Tuple:
    suffix = "__min"
    result = (
        *(
            (f.min(c) if isinstance(t, (NumericType, DateType, TimestampType)) else f.lit(None))
            .cast(t)
            .alias(c + suffix)
            for c, t in df_dtypes(df)
        ),
    )
    return result

def get_max(df: DataFrame) -> Tuple:
    suffix = "__max"
    result = (
        *(
            (f.max(c) if isinstance(t, (NumericType, DateType, TimestampType)) else f.lit(None))
            .cast(t)
            .alias(c + suffix)
            for c, t in df_dtypes(df)
        ),
    )
    return result

def get_std(df: DataFrame) -> Tuple:
    suffix = "__std"
    result = (
        *(
            (f.stddev(c) if isinstance(t, NumericType) else f.lit(None)).cast(t).alias(c + suffix)
            for c, t in df_dtypes(df)
        ),
    )
    return result

# build the big query

query = get_min(df) + get_max(df) + get_missing(df) + get_std(df)

# run the job

df.select(*query).show()

据我所知,由于Spark的内部工作正常,这个作业将并行运行。这种方法有效吗?这种方法的问题可能是它创建了大量带后缀的列,这会是一个瓶颈吗?

方法2:使用线程

在这种方法中,我可以利用Python线程尝试并发执行每个计算。

from pyspark import InheritableThread
from queue import Queue

def get_min(df: DataFrame, q: Queue) -> None:
    result = df.select(
        f.lit("min").alias("summary"),
        *(
            (f.min(c) if isinstance(t, (NumericType, DateType, TimestampType)) else f.lit(None))
            .cast(t)
            .alias(c)
            for c, t in df_dtypes(df)
        ),
    ).collect()
    q.put(result)

def get_max(df: DataFrame, q: Queue) -> None:
    result = df.select(
        f.lit("max").alias("summary"),
        *(
            (f.max(c) if isinstance(t, (NumericType, DateType, TimestampType)) else f.lit(None))
            .cast(t)
            .alias(c)
            for c, t in df_dtypes(df)
        ),
    ).collect()
    q.put(result)

def get_std(df: DataFrame, q: Queue) -> None:
    result = df.select(
        f.lit("std").alias("summary"),
        *(
            (f.stddev(c) if isinstance(t, NumericType) else f.lit(None)).cast(t).alias(c)
            for c, t in df_dtypes(df)
        ),
    ).collect()
    q.put(result)

def get_missing(df: DataFrame, q: Queue) -> None:
    result = df.select(
        f.lit("missing").alias("summary"),
        *(
            (
                f.count(
                    f.when(
                        (f.isnan(c) | f.isnull(c)),
                        c,
                    )
                )
                / f.count("*")
                * 100
                if isinstance(t, NumericType)  # isnan only works for numeric types
                else f.count(
                    f.when(
                        f.isnull(c),
                        c,
                    )
                )
                / f.count("*")
                * 100
            )
            .cast("double")
            .alias(c)
            for c, t in df_dtypes(df)
        ),
    ).collect()
    q.put(result)

# caching the dataframe to reuse it for all the jobs?

df.cache()

# I use a queue to retrieve the results from the threads

q = Queue()
threads = [
    InheritableThread(target=fun, args=(df, q)).start()
    for fun in (get_min, get_max, get_missing, get_std)
]

# and then some code to recover the results from the queue

这种方法的优点是不会产生很多带后缀的列,而只是原始列。但是我不确定这种方法如何处理GIL,这实际上是并行的吗?
你能告诉我你更喜欢哪一个吗?或者关于不同的计算方法的一些建议?
最后,我想用所有这些聚合的统计信息构建一个JSON。JSON的结构无关紧要,它取决于所采用的方法。对于第一个,我会得到类似于{“T__min”:无,“T__max”:无,“T__缺失”:1、“T__标准”:无,“C__min”:18.0,“C__最大值”:80.0,...}所以这样我就有了大量的字段,选择查询也会很大。对于第二种方法,我会用这些统计信息为每个变量获取一个JSON。

ar7v8xwq

ar7v8xwq1#

我不是很熟悉InheritableThreadQueue,但据我所知,您希望创建基于统计信息的线程。这意味着,每个线程计算不同的统计信息。这看起来并没有通过设计进行优化。我的意思是,某些统计信息可能会比其他统计信息计算得更快。这样,您在这些线程中的处理能力就不会得到利用。
如你所知,Spark是一个分布式计算系统,它为你执行所有的并行性。我非常怀疑你是否能用Python的工具来超越Spark的优化。如果我们能做到这一点,它早就集成到Spark中了。
第一种方法写得非常好:基于数据类型的条件语句,包含isnan,类型提示--做得很好。它可能会尽可能地执行最好的结果,它 * 绝对写得很高效 *。最大的缺点是它将在整个 Dataframe 上运行,但你真的无法逃避这一点。关于列的数量,你不应该担心。整个select语句将非常长,但这只是一个操作。逻辑/物理计划应该是有效的。在最坏的情况下,你可以在这个操作之前持久化/缓存 Dataframe ,因为如果这个 Dataframe 是用一些复杂的代码创建的,你可能会遇到问题。但除此之外,你应该没事。
对于某些统计信息,您可以考虑使用summary作为替代:

df.summary().show()

# +-------+-----+------------------+

# |summary|    T|                 C|

# +-------+-----+------------------+

# |  count|    3|                 3|

# |   mean| null|43.666666666666664|

# | stddev| null| 32.34707611722168|

# |    min|Alice|              18.0|

# |    25%| null|              18.0|

# |    50%| null|              33.0|

# |    75%| null|              80.0|

# |    max| Mark|              80.0|

# +-------+-----+------------------+

这种方法只适用于数字列和字符串列。日期/时间戳列(例如“D”)会被自动排除。但我不确定这样是否更有效。而且肯定会不太清楚,因为它会在代码中添加额外的逻辑,而现在代码是相当简单的。

相关问题