spark:如何将行分组到一个固定大小的数组中?

3duebb1j  于 2021-05-17  发布在  Spark
关注(0)|答案(1)|浏览(482)

我的数据集如下所示:

+---+
|col|
+---+
|  a|
|  b|
|  c|
|  d|
|  e|
|  f|
|  g|
+---+

我想重新格式化此数据集,以便将行聚合为固定长度的数组,如下所示:

+------+
|   col|
+------+
|[a, b]|
|[c, d]|
|[e, f]|
|   [g]|
+------+

我试过这个: spark.sql("select collect_list(col) from (select col, row_number() over (order by col) row_number from dataset) group by floor(row_number/2)") 但问题是,我的实际数据集太大,无法在一个分区中处理行号()

n3schb8v

n3schb8v1#

当您希望分发此文件时,有几个步骤是必要的。
如果您希望运行代码,我将从以下内容开始:

var df = List(
  "a", "b", "c", "d", "e", "f", "g"
).toDF("col")
val desiredArrayLength = 2

首先,将tyour的Dataframe拆分为一个小的Dataframe,您可以在单个节点上处理,而大的Dataframe的行数是所需数组大小的倍数(在您的示例中,这是2)

val nRowsPrune = 1 //number of rows to prune such that remaining dataframe has number of
                   // rows is multiples of the desired length of array
val dfPrune = df.sort(desc("col")).limit(nRowsPrune)
df = df.join(dfPrune,Seq("col"),"left_anti") //separate small from large dataframe

通过构造,您可以将原始代码应用于小Dataframe,

val groupedPruneDf = dfPrune//.withColumn("g",floor((lit(-1)+row_number().over(w))/lit(desiredArrayLength ))) //added -1 as row-number starts from 1
                            //.groupBy("g")
                            .agg( collect_list("col").alias("col"))
                            .select("col")

现在,我们需要找到一种方法来处理剩余的大型Dataframe。但是,现在我们确定df有很多行,是数组大小的倍数。在这里我们使用了一个很好的技巧,即使用 repartitionByRange . 基本上,分区保证保留排序,并且当您进行分区时,每个分区将具有相同的大小。现在,可以收集每个分区中的每个数组,

val nRows = df.count()
   val maxNRowsPartition = desiredArrayLength //make sure its a multiple of desired array length
   val nPartitions = math.max(1,math.floor(nRows/maxNRowsPartition) ).toInt
   df = df.repartitionByRange(nPartitions, $"col".desc)
          .withColumn("partitionId",spark_partition_id())

    val w = Window.partitionBy($"partitionId").orderBy("col")
    val groupedDf = df
        .withColumn("g",  floor( (lit(-1)+row_number().over(w))/lit(desiredArrayLength ))) //added -1 as row-number starts from 1
        .groupBy("partitionId","g")
        .agg( collect_list("col").alias("col"))
        .select("col")

最后把这两个结果结合起来就得到了你想要的结果,

val result = groupedDf.union(groupedPruneDf)
result.show(truncate=false)

相关问题