让我们使用以下虚拟数据:
df = spark.createDataFrame([(1,2),(1,3),(1,40),(1,0),(2,3),(2,1),(2,4),(3,2),(3,4)],['a','b'])
df.show()
+---+---+
| a| b|
+---+---+
| 1| 2|
| 1| 3|
| 1| 40|
| 1| 0|
| 2| 3|
| 2| 1|
| 2| 4|
| 3| 2|
| 3| 4|
+---+---+
1.如何过滤出平均值(b)不大于6的数据组
预期产出:
+---+---+
| a| b|
+---+---+
| 1| 2|
| 1| 3|
| 1| 40|
| 1| 0|
+---+---+
我如何实现目标:
df_filter = df.groupby('a').agg(F.mean(F.col('b')).alias("avg"))
df_filter = df_filter.filter(F.col('avg') > 6.)
df.join(df_filter,'a','inner').drop('avg').show()
问题:
- shuffle发生两次,一次用于计算df_filter,另一次用于连接。
df_filter = df.groupby('a').agg(F.mean(F.col('b')).alias("avg"))
df_filter = df_filter.filter(F.col('avg') > 6.)
df.join(df_filter,'a','inner').drop('avg').explain()
== Physical Plan ==
*(5) Project [a#175L, b#176L]
+- *(5) SortMergeJoin [a#175L], [a#222L], Inner
:- *(2) Sort [a#175L ASC NULLS FIRST], false, 0
: +- Exchange hashpartitioning(a#175L, 200), ENSURE_REQUIREMENTS, [plan_id=919]
: +- *(1) Filter isnotnull(a#175L)
: +- *(1) Scan ExistingRDD[a#175L,b#176L]
+- *(4) Sort [a#222L ASC NULLS FIRST], false, 0
+- *(4) Project [a#222L]
+- *(4) Filter (isnotnull(avg#219) AND (avg#219 > 6.0))
+- *(4) HashAggregate(keys=[a#222L], functions=[avg(b#223L)])
+- Exchange hashpartitioning(a#222L, 200), ENSURE_REQUIREMENTS, [plan_id=925]
+- *(3) HashAggregate(keys=[a#222L], functions=[partial_avg(b#223L)])
+- *(3) Filter isnotnull(a#222L)
+- *(3) Scan ExistingRDD[a#222L,b#223L]
如果我考虑一下,我应该只在键a
上 Shuffle 一次,然后不再需要 Shuffle ,因为每个分区都是自给自足的。
问题:一般来说,排除不满足组相关筛选器的数据组的有效方法是什么?
1条答案
按热度按时间yh2wf1be1#
您可以使用
Window
功能来代替groupBy + join,