使用dict优化列上的PySpark Dataframe聚合

2sbarzqh  于 2023-11-21  发布在  Apache
关注(0)|答案(1)|浏览(106)

尝试处理具有以下结构的DataFrame:
| 名称|国家|选项|
| --|--|--|
| 虚拟1| [“UK”,“FR”]|[{“id”:1,“key1”:10,“key2”:20},{“id”:2,“key1”:10,“key2”:20}]|
| 虚拟1| [“FR”]|[{“id”:1,“key1”:20,“key2”:30}]|
| Dummy2| [“UK”,“FR”]|[{“id”:1,“key1”:10,“key2”:20}]|
变成这样
| 名称|国家|选项|
| --|--|--|
| 虚拟1| [“UK”,“FR”]|[{“id”:1,“key1”:20,“key2”:30},{“id”:2,“key1”:10,“key2”:20}]|
| Dummy2| [“UK”,“FR”]|[{“id”:1,“key1”:10,“key2”:20}]|
因此,基本上的想法是按name进行分组,得到countries的不同列表和选项列表,其中,如果对id进行分组并且仅保留最高的key1(或者如果key1相同,则为key2),
我能够通过这样的国家聚合来解决这个问题

countries = df.groupBy(col("name") \ 
              .agg(array_distinct(flatten(collect_list(col("countries")))

字符串
并将其加入选项聚合

window = Window().partitionBy("name").orderBy(col("key1").desc(), col("key2").desc()) 
options = df.select("name", explode("options").alias("options")) \ 
            .withColumn("id", col("options.id") \ 
            .withColumn("key1", col("options.key1") \ 
            .withColumn("key2", col("options.key2") \ 
            .withColumn("rn", row_number().over(window)) \ 
            .filter("rn = 1") \ 
            .groupBy(col("name")) \ 
            .agg(collect_list(options))


然而,这是相当慢的给予的数据量处理。我会猜测,试图使它在一次过,没有连接会更快,但我实现的方式也相当慢。
你看到任何优化我可以做我的代码,使它运行得更快?谢谢

gdrx4gfi

gdrx4gfi1#

一气呵成:对于获取唯一的国家,可以使用collect_set函数与单独的窗口,并在以后进行扁平化。在Scala上(猜测,可以轻松转换为Python):

// case class Keys(id: Int, key1: Int, key2: Int)
val df = Seq(
  ("dummy1", Seq("UK", "FR"), Seq(
    Keys(1, 10, 20),
    Keys(2, 10, 20)
  )),
  ("dummy1", Seq("FR"), Seq(
    Keys(1, 20, 30)
  )),
  ("dummy2", Seq("UK", "FR"), Seq(
    Keys(1, 10, 20)
  ))
).toDF("name", "countries", "options")

val nameIdWindow = Window.partitionBy("name", "option.id")
  .orderBy(desc("option.key1"), desc("option.key2"))

val nameWindow = Window.partitionBy("name")

val result = df
  .select(
    $"name",
    collect_set($"countries").over(nameWindow).alias("countries_array_arrays"),
    explode($"options").alias("option")
  )
  .withColumn("rn", row_number().over(nameIdWindow))
  .filter("rn = 1")
  .withColumn("countries", array_distinct(flatten($"countries_array_arrays")))
  .groupBy($"name", $"countries")
  .agg(collect_set("option").alias("options"))

字符串
测试结果:

+------+---------+--------------------------+
|name  |countries|options                   |
+------+---------+--------------------------+
|dummy1|[FR, UK] |[{1, 20, 30}, {2, 10, 20}]|
|dummy2|[UK, FR] |[{1, 10, 20}]             |
+------+---------+--------------------------+

相关问题