Scala/Spark;向DataFrame添加列,当某个值在另一列中重复时,该列递增1

yquaqz18  于 2023-02-04  发布在  Scala
关注(0)|答案(3)|浏览(234)

我有一个叫做rawEDI的 Dataframe ,看起来像这样
| 行号|节段|
| - ------|- ------|
| 1个|标准时间|
| 第二章|最佳点|
| 三个|标准误|
| 四个|标准时间|
| 五个|最佳点|
| 六个|N1|
| 七|标准误|
| 八个|标准时间|
| 九|PTD|
| 十个|标准误|
每一行代表一个文件中的一行,每一行被称为一个段,用一个段标识符来表示;一个短字符串。段被分组在以ST段标识符开始并以SE段标识符结束的组块中。给定文件中可以有任意数量的ST组块,并且每个ST组块的大小不固定。
我想在 Dataframe 上创建一个新列,用数字表示给定段所属的ST组,这样我就可以使用groupBy对所有ST段执行聚合操作,而不必在每个单独的ST段上循环,因为这样太慢了。
最终的DataFrame将如下所示;
| 行号|节段|ST_集团|
| - ------|- ------|- ------|
| 1个|标准时间|1个|
| 第二章|最佳点|1个|
| 三个|标准误|1个|
| 四个|标准时间|第二章|
| 五个|最佳点|第二章|
| 六个|N1|第二章|
| 七|标准误|第二章|
| 八个|标准时间|三个|
| 九|PTD|三个|
| 十个|标准误|三个|
简而言之,我想创建一个DataFrame列,并填充一个数字,每当值"ST"出现在Segment列中时,该数字就递增1。
我使用的是spark 2.3.2和scala 2.11.8
我最初的想法是使用迭代,我收集了另一个DataFrame df,它包含每个片段的起始和结束line_number,看起来像这样;
| 开始|完|
| - ------|- ------|
| 1个|三个|
| 四个|七|
| 八个|十个|
然后迭代 Dataframe 的行,并使用它们填充新列,如下所示;

var st = 1
for (row <- df.collect()) {
    val start = row(0)
    val end  = row(1)
    var labelSTs = rawEDI.filter("line_number > = ${start}").filter("line_number <= ${end}").withColumn("ST_Group", lit(st))
    st = st + 1

然而,这会产生一个空的DataFrame。另外,使用for循环是非常耗时的,在我的机器上要花费20秒以上的时间。不使用循环来实现这个结果将是巨大的,但是如果性能良好,使用循环的解决方案也是可以接受的。
我有一种预感,这可以使用udf或窗口来完成,但我不确定如何攻击它。
这个

val func = udf((s:String) => if(s == "ST") 1 else 0)
var labelSTs = rawEDI.withColumn("ST_Group", func((col("segment")))

仅在每个ST段开始时用1填充列。
还有这个

val w = Window.partitionBy("Segment").orderBy("line_number")
val labelSTs = rawEDI.withColumn("ST_Group", row_number().over(w)

返回无意义 Dataframe 。

yshpjwxd

yshpjwxd1#

一种方法是创建一个"组"的中间 Dataframe ,告诉您每个组在哪一行开始和结束(类似于您已经做过的),然后使用大于/小于条件将其连接到原始表。
样本数据

scala> val input = Seq((1,"ST"),(2,"BPT"),(3,"SE"),(4,"ST"),(5,"BPT"),
                       (6,"N1"),(7,"SE"),(8,"ST"),(9,"PTD"),(10,"SE"))
                   .toDF("linenumber","segment")
scala> input.show(false)
+----------+-------+
|linenumber|segment|
+----------+-------+
|1         |ST     |
|2         |BPT    |
|3         |SE     |
|4         |ST     |
|5         |BPT    |
|6         |N1     |
|7         |SE     |
|8         |ST     |
|9         |PTD    |
|10        |SE     |
+----------+-------+

使用Window为组创建一个 Dataframe ,就像您的直觉告诉您的那样:

scala> val groups = input.where("segment='ST'")
                    .withColumn("endline",lead("linenumber",1) over Window.orderBy("linenumber"))
                    .withColumn("groupnumber",row_number() over Window.orderBy("linenumber"))
                    .withColumnRenamed("linenumber","startline")
                    .drop("segment")

scala> groups.show(false)
+---------+-----------+-------+
|startline|groupnumber|endline|
+---------+-----------+-------+
|1        |1          |4      |
|4        |2          |8      |
|8        |3          |null   |
+---------+-----------+-------+

将两者连接以获得结果

scala> input.join(groups,
                  input("linenumber") >= groups("startline") && 
                  (input("linenumber") < groups("endline") || groups("endline").isNull))
            .select("linenumber","segment","groupnumber") 
            .show(false)
+----------+-------+-----------+
|linenumber|segment|groupnumber|
+----------+-------+-----------+
|1         |ST     |1          |
|2         |BPT    |1          |
|3         |SE     |1          |
|4         |ST     |2          |
|5         |BPT    |2          |
|6         |N1     |2          |
|7         |SE     |2          |
|8         |ST     |3          |
|9         |PTD    |3          |
|10        |SE     |3          |
+----------+-------+-----------+

这样做的唯一问题是Window.orderBy()在未分区的 Dataframe 上,它会将所有数据收集到一个分区,因此可能是一个杀手。

k4aesqcs

k4aesqcs2#

如果您只想添加一个列,其编号在值“ST”出现在Segment列中时递增1,您可以在单独的 Dataframe 中过滤具有ST段的行,

var labelSTs = rawEDI.filter("segement == 'ST'");
// then group by ST and collect to list the linenumbers
var groupedDf = labelSTs.groupBy("Segment").agg(collect_list("Line_number").alias("Line_numbers"))
// now you need to flat back the data frame and log the line number index 
var flattedDf = groupedDf.select($"Segment", explode($"Line_numbers").as("Line_number"))   
// log the line_number index in your target column ST_Group
val withIndexDF = flattenedDF.withColumn("ST_Group", row_number().over(Window.partitionBy($"Segment").orderBy($"Line_number")))

结果是这样的:

+-------+----------+----------------+
|Segment|Line_number|ST_Group       |
+-------+----------+----------------+
|     ST|         1|               1|
|     ST|         4|               2|
|     ST|         8|               3|
+-------|----------|----------------|

然后将其与初始 Dataframe 中的其他段连接。

z0qdvdin

z0qdvdin3#

找到了一个更简单的方法,添加一列,当段列值为ST时,该列的值为1,否则该列的值为0。然后使用Window函数查找该新列的累积和。这将给您所需的结果。

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
val rawEDI=Seq((1,"ST"),(2,"BPT"),(3,"SE"),(4,"ST"),(5,"BPT"),(6,"N1"),(7,"SE"),(8,"ST"),(9,"PTD"),(10,"SE")).toDF("line_number","segment")

val newDf=rawEDI.withColumn("ST_Group", ($"segment" === "ST").cast("bigint"))

val windowSpec = Window.orderBy("line_number")
newDf.withColumn("ST_Group", sum("ST_Group").over(windowSpec))
.show
+-----------+-------+--------+
|line_number|segment|ST_Group|
+-----------+-------+--------+
|          1|     ST|       1|
|          2|    BPT|       1|
|          3|     SE|       1|
|          4|     ST|       2|
|          5|    BPT|       2|
|          6|     N1|       2|
|          7|     SE|       2|
|          8|     ST|       3|
|          9|    PTD|       3|
|         10|     SE|       3|
+-----------+-------+--------+

相关问题