在PySpark Dataframe 中找到一系列具有相同值的连续行

wvmv3b1j  于 2022-11-01  发布在  Spark
关注(0)|答案(1)|浏览(155)

我有一个按日期排序的 Dataframe 。每行包含一个“flag”列(值为1或0)。我希望找到“flag”值等于1的3个(或更多)连续行的序列。目标是如果它不是:flag”为1的3个或更多连续元素序列的一部分,则将“flag”值重置为0。
下面是一个数据示例:
| 日期|旗帜|
| - -|- -|
| 2022年1月1日|一个|
| 2022年1月2日|一个|
| 2022年1月3日|一个|
| 2022年1月4日|一个|
| 2022年1月5日|第0页|
| 2022年1月6日|第0页|
| 2022年1月7日|一个|
| 2022年1月8日|一个|
| 2022年1月9日|第0页|
| 2022年1月10日|一个|
我们只需要将前4行的值保留为1,因为它们构成了一个由4行(多于3行)组成的序列,标志中的值为1。
| 日期|旗帜|
| - -|- -|
| 2022年1月1日|一个|
| 2022年1月2日|一个|
| 2022年1月3日|一个|
| 2022年1月4日|一个|
| 2022年1月5日|第0页|
| 2022年1月6日|第0页|
| 2022年1月7日|第0页|
| 2022年1月8日|第0页|
| 2022年1月9日|第0页|
| 2022年1月10日|第0页|
我想,也许使用基于前一个元素的lag函数是有意义的,但不确定它在PySpark中的效率如何。

dgsult0t

dgsult0t1#

您将需要使用几个窗口函数。在3个不同的窗口中计数标志:-2:0-1:10:-2。如果这些中至少有一个的sum3,那么你有3个连续的1。
在下面的脚本中,我假设日期以 string 数据类型存储,因此我使用了列表达式date从字符串中读取真实日期。如果日期不是字符串格式,则不应使用该行。
输入:

from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [('01-01-2022', 1),
     ('02-01-2022', 1),
     ('03-01-2022', 1),
     ('04-01-2022', 1),
     ('05-01-2022', 0),
     ('06-01-2022', 0),
     ('07-01-2022', 1),
     ('08-01-2022', 1),
     ('09-01-2022', 0),
     ('10-01-2022', 1)],
    ['date', 'flag'])

脚本:

date = F.to_date('date', 'dd-MM-yyyy')
cond1 = F.sum('flag').over(W.orderBy(date).rowsBetween(-2, 0)) == 3
cond2 = F.sum('flag').over(W.orderBy(date).rowsBetween(-1, 1)) == 3
cond3 = F.sum('flag').over(W.orderBy(date).rowsBetween(0, 2)) == 3
df = df.withColumn('flag', (cond1 | cond2 | cond3).cast('long'))

df.show()

# +----------+----+

# |      date|flag|

# +----------+----+

# |01-01-2022|   1|

# |02-01-2022|   1|

# |03-01-2022|   1|

# |04-01-2022|   1|

# |05-01-2022|   0|

# |06-01-2022|   0|

# |07-01-2022|   0|

# |08-01-2022|   0|

# |09-01-2022|   0|

# |10-01-2022|   0|

# +----------+----+

相关问题