How do I efficiently keep only the DataFrame groups that satisfy a group-level filter in PySpark?

zqdjd7g9 · asked 2023-02-05 · in Apache

Let's use the following dummy data:

df = spark.createDataFrame([(1,2),(1,3),(1,40),(1,0),(2,3),(2,1),(2,4),(3,2),(3,4)],['a','b'])
df.show()
+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
|  2|  3|
|  2|  1|
|  2|  4|
|  3|  2|
|  3|  4|
+---+---+

1. How can I filter out the groups whose mean of b is not greater than 6, i.e. keep only the groups where mean(b) > 6?

Expected output:

+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
+---+---+
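To pin down the intended semantics independently of Spark, the group filter can be sketched in plain Python: compute each group's mean of b, then keep only the rows whose group mean exceeds 6. The variable names below are illustrative, not from the original post.

```python
from collections import defaultdict
from statistics import mean

rows = [(1, 2), (1, 3), (1, 40), (1, 0),
        (2, 3), (2, 1), (2, 4), (3, 2), (3, 4)]

# Collect the b-values per key a and compute each group's mean.
groups = defaultdict(list)
for a, b in rows:
    groups[a].append(b)

# Keys whose group mean passes the filter (only group 1: mean 11.25).
keep = {a for a, bs in groups.items() if mean(bs) > 6}

# Keep only rows belonging to a passing group.
filtered = [(a, b) for a, b in rows if a in keep]
```

Here `filtered` is exactly the four rows of group 1 shown in the expected output.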

How I currently achieve this:

from pyspark.sql import functions as F

df_filter = df.groupby('a').agg(F.mean(F.col('b')).alias("avg"))
df_filter = df_filter.filter(F.col('avg') > 6.)
df.join(df_filter, 'a', 'inner').drop('avg').show()

The problem:

  1. The shuffle happens twice: once to compute df_filter, and once more for the join:
df_filter = df.groupby('a').agg(F.mean(F.col('b')).alias("avg"))

df_filter = df_filter.filter(F.col('avg') > 6.)

df.join(df_filter,'a','inner').drop('avg').explain()
== Physical Plan ==
*(5) Project [a#175L, b#176L]
+- *(5) SortMergeJoin [a#175L], [a#222L], Inner
   :- *(2) Sort [a#175L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(a#175L, 200), ENSURE_REQUIREMENTS, [plan_id=919]
   :     +- *(1) Filter isnotnull(a#175L)
   :        +- *(1) Scan ExistingRDD[a#175L,b#176L]
   +- *(4) Sort [a#222L ASC NULLS FIRST], false, 0
      +- *(4) Project [a#222L]
         +- *(4) Filter (isnotnull(avg#219) AND (avg#219 > 6.0))
            +- *(4) HashAggregate(keys=[a#222L], functions=[avg(b#223L)])
               +- Exchange hashpartitioning(a#222L, 200), ENSURE_REQUIREMENTS, [plan_id=925]
                  +- *(3) HashAggregate(keys=[a#222L], functions=[partial_avg(b#223L)])
                     +- *(3) Filter isnotnull(a#222L)
                        +- *(3) Scan ExistingRDD[a#222L,b#223L]

Thinking about it, a single shuffle on the key a should be enough: after that, every partition contains whole groups and is self-sufficient, so no further shuffle should be needed.
Question: in general, what is an efficient way to drop the groups that do not satisfy a group-dependent filter?

yh2wf1be1#

Instead of groupBy + join, you can use a Window function:

from pyspark.sql import Window
from pyspark.sql.functions import avg, col

out = df.withColumn("avg", avg(col("b")).over(Window.partitionBy("a")))\
    .where("avg > 6").drop("avg")

out.explain()
out.show()

+- Project [a#0L, b#1L]
   +- Filter (isnotnull(avg#5) AND (avg#5 > 6.0))
      +- Window [avg(b#1L) windowspecdefinition(a#0L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS avg#5], [a#0L]
         +- Sort [a#0L ASC NULLS FIRST], false, 0
            +- Exchange hashpartitioning(a#0L, 200), ENSURE_REQUIREMENTS, [plan_id=16]
               +- Scan ExistingRDD[a#0L,b#1L]

+---+---+
|  a|  b|
+---+---+
|  1|  2|
|  1|  3|
|  1| 40|
|  1|  0|
+---+---+
