在pyspark中按元素对数组列求和-仅使用spark函数

oprakyz7  于 2023-10-15  发布在  Spark
关注(0)|答案(3)|浏览(129)

我想按元素对一列数组中的数组求和-这列数组应该聚合为一个数组。下面的代码给出了所需的结果[3,6,9],但它使用了一个UDF,在缩放时会导致OOM。我希望有同样的结果,但纯粹在Spark!

import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType,ArrayType

def sum_arr_by_element(l):
    return [sum(x) for x in zip(*l)]

my_udf = udf(sum_arr_by_element, ArrayType(IntegerType()))

data = [
    ([1, 2, 3],),
    ([1, 2, 3],),
    ([1, 2, 3],),
]

df = spark.createDataFrame(data, ["array_column"])

df.agg(F.collect_list("array_column").alias('all_lists')).withColumn('summed',my_udf("all_lists")).select('summed').display()
db2dz4w8

db2dz4w81#

试试这个:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType, ArrayType

spark = SparkSession.builder.appName("toto").getOrCreate()

data = [
    ([1, 2, 3],),
    ([1, 2, 3],),
    ([1, 2, 3],),
]

df = spark.createDataFrame(data, ["array_column"])

df = df.withColumn("id", F.monotonically_increasing_id())
df = df.selectExpr("id", "posexplode(array_column) as (pos, value)")
result_df = df.groupBy("pos").agg(F.sum("value").alias("summed_value"))
result_array = result_df.orderBy("pos").select("summed_value").rdd.flatMap(lambda x: x).collect()

print(result_array)
ovfsdjhp

ovfsdjhp2#

这里有一种方法,首先按行聚合数组的元素,然后按元素在原始列表中的位置顺序收集元素

df_result = (
    df
    .select(F.posexplode('array_column'))
    .groupby('pos')
    .agg(F.expr("struct(pos, sum(col) as col) as pos_col"))
    .agg(F.expr("transform(array_sort(collect_list(pos_col)), x -> x.col) as array_column"))
)
+------------+
|array_column|
+------------+
|   [3, 6, 9]|
+------------+
jpfvwuh4

jpfvwuh43#

你可以在Spark中使用阵列的最大功率。
如果你有一个分组列,收集组内的所有数组,并使用aggregate函数。该函数将保留元素位置。
这里有一个例子

data_sdf. \
    groupBy(func.lit(1).alias('group')). \
    agg(func.collect_list('arr_col').alias('all_arrs')). \
    withColumn('arr_sum',
               func.expr('''
                         aggregate(slice(all_arrs, 2, size(all_arrs)), 
                                   all_arrs[0], 
                                   (x, y) -> zip_with(x, y, (a, b) -> a+b)
                                   )
                         ''')
               ). \
    show(truncate=False)

# +-----+---------------------------------+---------+
# |group|all_arrs                         |arr_sum  |
# +-----+---------------------------------+---------+
# |1    |[[1, 2, 3], [1, 2, 3], [1, 2, 3]]|[3, 6, 9]|
# +-----+---------------------------------+---------+

相关问题