PySpark聚合操作,用于对MapType(*,IntegerType())类型DataFrame列中的所有行求和

xpcnnkqh  于 2022-12-03  发布在  Spark
关注(0)|答案(2)|浏览(117)

假设您创建了一个带有精确模式的Spark DataFrame:

import pyspark.sql.functions as sf
from pyspark.sql.types import *

dfschema = StructType([
    StructField("_1", ArrayType(IntegerType())),
    StructField("_2", ArrayType(IntegerType())),
])
df = spark.createDataFrame([[[1, 2, 5], [13, 74, 1]], 
                            [[1, 2, 3], [77, 23, 15]]
                           
                           ], schema=dfschema)
df = df.select(sf.map_from_arrays("_1", "_2").alias("omap"))
df = df.withColumn("id", sf.lit(1))

上面的DataFrame如下所示:

+---------------------------+---+
|omap                       |id |
+---------------------------+---+
|{1 -> 13, 2 -> 74, 5 -> 1} |1  |
|{1 -> 77, 2 -> 23, 3 -> 15}|1  |
+---------------------------+---+

我想执行以下操作:

df.groupby("id").agg(sum_counter("omap")).show(truncate=False)

您能否帮助我定义一个sum_counter函数,该函数仅使用pyspark.sql.functions中的SQL函数(因此没有UDF),从而允许我在输出中获得这样一个DataFrame:

+---+-----------------------------------+
|id |mapsum                             |
+---+-----------------------------------+
|1  |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+

我可以使用applyInPandas来解决这个问题:

from pyspark.sql.types import *
from collections import Counter
import pandas as pd

reschema = StructType([
    StructField("id", LongType()),
    StructField("mapsum", MapType(IntegerType(), IntegerType()))
])

def sum_counter(key: int, pdf: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame([
        key
        + (sum([Counter(x) for x in pdf["omap"]], Counter()), )
    ])

df.groupby("id").applyInPandas(sum_counter, reschema).show(truncate=False)

+---+-----------------------------------+
|id |mapsum                             |
+---+-----------------------------------+
|1  |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+

但是,出于性能原因,我希望避免使用applyInPandasUDFs

ssm49v7z

ssm49v7z1#

您可以先将omap分解为单独的行,其中key和value将设置在单独的列中,然后按如下方式聚合它们:

exploded_df = df.select("*", sf.explode("omap"))
agg_df = exploded_df.groupBy("id", "key").sum("value")
agg_df.groupBy("id").agg(sf.map_from_entries(sf.collect_list(sf.struct("key","sum(value)"))).alias("mapsum")).show(truncate=False)

+---+-----------------------------------+
|id |mapsum                             |
+---+-----------------------------------+
|1  |{2 -> 97, 1 -> 90, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+
brccelvz

brccelvz2#

最后我这样解决了它:

import pyspark.sql.functions as sf

def sum_counter(mapcoln: str):
    dkeys = sf.array_distinct(sf.flatten(sf.collect_list(sf.map_keys(mapcoln))))
    dkeyscount = sf.transform(
        dkeys,
        lambda ukey: sf.aggregate(
            sf.collect_list(mapcoln),
            sf.lit(0),
            lambda acc, mapentry: sf.when(
                ~sf.isnull(sf.element_at(mapentry, ukey)),
                acc + sf.element_at(mapentry, ukey),
            ).otherwise(acc),
        ),
    )
    return sf.map_from_arrays(dkeys, dkeyscount).alias("mapsum")

df.groupby("id").agg(sum_counter("omap")).show(truncate=False)

+---+-----------------------------------+
|id |mapsum                             |
+---+-----------------------------------+
|1  |{1 -> 90, 2 -> 97, 5 -> 1, 3 -> 15}|
+---+-----------------------------------+

相关问题