python - How do I apply conditions to grouped records in a PySpark DataFrame?

h79rfbju posted on 2023-03-11 in Python

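I'm looking for a way to check several conditions within a group. First, I check for overlaps between records (per ID); second, I should make an exception for the record with the highest tran number within the same contiguous run of overlapping records. On top of that, the same ID can have several separate overlaps. For example: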

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.getOrCreate()

data = [('A',1000,1,100),
        ('B',1001,0,10),
        ('B',1002,10,15),
        ('B',1002,20,22),
        ('B',1003,25,50),
        ('B',1004,50,55),
        ('B',1005,53,56),
        ('B',1006,60,100),
        ('C',1007,1,100)
       ]

schema = StructType([
    StructField("id", StringType(), True),
    StructField("tran", IntegerType(), True),
    StructField("start", IntegerType(), True),
    StructField("end", IntegerType(), True),
])

df = spark.createDataFrame(data=data,schema=schema)
df.show()

+---+----+-----+---+
| id|tran|start|end|
+---+----+-----+---+
|  A|1000|    1|100|
|  B|1001|    0| 10|
|  B|1002|   10| 15|
|  B|1003|   20| 22|
|  B|1004|   25| 50|
|  B|1005|   50| 55|
|  B|1006|   53| 56|
|  B|1007|   60|100|
|  C|1008|    1|100|
+---+----+-----+---+

The desired DataFrame should look like this:

+---+----+-----+---+-----+
| id|tran|start|end|valid|
+---+----+-----+---+-----+
|  A|1000|    1|100|  yes| # valid: no other record with id A, so nothing overlaps it
|  B|1001|    0| 10|   no| # invalid: it overlaps the next record of the same id
|  B|1002|   10| 15|  yes| # overlaps the previous one, but has the higher tran number of the two
|  B|1003|   20| 22|  yes| # valid: no overlap
|  B|1004|   25| 50|   no| # invalid: overlaps and its tran is not the highest
|  B|1005|   50| 55|   no| # invalid: overlaps and its tran is not the highest
|  B|1006|   53| 56|  yes| # overlaps the previous ones, but has the highest tran number of the three contiguously overlapping records
|  B|1007|   60|100|  yes| # valid: no overlap
|  C|1008|    1|100|  yes| # valid: no overlap
+---+----+-----+---+-----+
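To pin down the rule, here is a small pure-Python sketch (no Spark) of one reading of it. Assumption: within each id, records are sorted by start, and a record joins the current run when its start is at or before the largest end seen so far in that run (so touching endpoints such as 10/10 count as an overlap, as with 1001/1002). Only the highest tran in each run is "yes". The names `label_valid` and `_mark_run` are hypothetical.

```python
from itertools import groupby
from typing import List, Tuple

# (id, tran, start, end), same layout as the question's rows
Record = Tuple[str, int, int, int]


def _mark_run(records: List[Record], run: List[int], labels: List[str]) -> None:
    """Label the record with the highest tran in `run` as 'yes'."""
    if run:
        best = max(run, key=lambda i: records[i][1])
        labels[best] = "yes"


def label_valid(records: List[Record]) -> List[str]:
    """Return 'yes'/'no' per record, in the same order as the input."""
    labels = ["no"] * len(records)
    # Sort indices by (id, start, tran) so each id's records are adjacent.
    order = sorted(range(len(records)),
                   key=lambda i: (records[i][0], records[i][2], records[i][1]))
    for _, idx_group in groupby(order, key=lambda i: records[i][0]):
        run: List[int] = []   # indices of the current overlapping run
        run_end = None        # largest end seen so far in that run
        for i in idx_group:
            _, _, start, end = records[i]
            if run and start <= run_end:
                # Overlaps (or touches) the run: extend it.
                run.append(i)
                run_end = max(run_end, end)
            else:
                # Gap: close the previous run and start a new one.
                _mark_run(records, run, labels)
                run, run_end = [i], end
        _mark_run(records, run, labels)
    return labels


data = [('A', 1000, 1, 100), ('B', 1001, 0, 10), ('B', 1002, 10, 15),
        ('B', 1003, 20, 22), ('B', 1004, 25, 50), ('B', 1005, 50, 55),
        ('B', 1006, 53, 56), ('B', 1007, 60, 100), ('C', 1008, 1, 100)]
print(label_valid(data))
# ['yes', 'no', 'yes', 'yes', 'no', 'no', 'yes', 'yes', 'yes']
```

On the question's data (the version matching the `show()` output) this reproduces the desired `valid` column. Because it tracks the running maximum end, a long record such as (0, 100) also pulls later, non-adjacent records into the same run.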

Many thanks to the legend who solves this :)

iqxoj9l9


1. Import the necessary packages

from pyspark.sql.functions import *
from pyspark.sql.window import Window

2. Create the DataFrame

data = [('A',1000,1,100),
        ('B',1001,0,10),
        ('B',1002,10,15),
        ('B',1003,20,22), # The question's data and its show() output disagree; this follows the output
        ('B',1004,25,50),
        ('B',1005,50,55),
        ('B',1006,53,56),
        ('B',1007,60,100),
        ('C',1008,1,100)]

df = spark.createDataFrame(data,"id STRING, tran INT, start INT, end INT")

3. Add some extra columns needed to compare each record with its neighbouring records

df = df.withColumn("row_id", row_number().over(Window.orderBy("tran")))  # remember the row order; orderBy(lit(1)) gives no guaranteed order

df = df.withColumn("lag_tran", lag("tran", 1).over(Window.orderBy("tran")))  # null for the first row, no string default on an INT column

df = df.withColumn("lead_id", lead("id", 1, "null").over(Window.orderBy("tran")))

df = df.withColumn("lag_end", lag("end", 1, -1).over(Window.orderBy("tran")))

df = df.withColumn("lead_start", lead("start", 1, -1).over(Window.orderBy("tran")))

df = df.withColumn("valid", lit("no"))

4. Select the records according to the conditions you specified

# a. Select records whose next row (by tran) has a different id, i.e. the last record of each id: always valid
non_id_overlap = df.filter("id != lead_id").withColumn("valid", lit("yes"))

# b. Select records that are followed by another record with the same id
id_overlap = df.filter("id == lead_id")

# c. Of those, select records whose range overlaps neither the previous nor the next record: valid
non_time_lap = id_overlap.filter("(start > lag_end) and (end < lead_start)").withColumn("valid", lit("yes"))

# d. Select records whose range overlaps the previous or the next record
time_overlap = id_overlap.filter("(start <= lag_end) or (end >= lead_start)")

# e. Select records whose end overlaps the next record's start: invalid (valid stays "no")
overlap_next = time_overlap.filter("end >= lead_start")

# f. Select records that overlap only the previous record and have the higher tran: valid
overlap_prev = time_overlap.filter("not(end >= lead_start) and (tran > lag_tran)").withColumn("valid", lit("yes"))

5. Union all the matched DataFrames together

matched_df = non_id_overlap
# Use a new loop variable: reusing the name df would overwrite the DataFrame that step 6 needs
for part in [non_time_lap, overlap_next, overlap_prev]:
    matched_df = matched_df.unionByName(part)

6. Find the records that were not matched

mismatched_df = df.drop("valid").subtract(matched_df.drop("valid")).withColumn("valid", lit("no"))

7. Union them, sort, and drop the helper columns to get the desired output

final_df = matched_df.unionByName(mismatched_df)
final_df = final_df.orderBy("row_id").drop("row_id","lead_id","lag_tran","lag_end","lead_start")
final_df.show()

Here is the result:

+---+----+-----+---+-----+
| id|tran|start|end|valid|
+---+----+-----+---+-----+
|  A|1000|    1|100|  yes|
|  B|1001|    0| 10|   no|
|  B|1002|   10| 15|  yes|
|  B|1003|   20| 22|  yes|
|  B|1004|   25| 50|   no|
|  B|1005|   50| 55|   no|
|  B|1006|   53| 56|  yes|
|  B|1007|   60|100|  yes|
|  C|1008|    1|100|  yes|
+---+----+-----+---+-----+
