How to compute a rolling median with Window() in PySpark?

s1ag04yj  posted on 2023-02-15  in Spark

How can I compute a rolling median of the dollars column, with a window covering the previous 3 values?

    • Input data
dollars timestampGMT       
25      2017-03-18 11:27:18
17      2017-03-18 11:27:19
13      2017-03-18 11:27:20
27      2017-03-18 11:27:21
13      2017-03-18 11:27:22
43      2017-03-18 11:27:23
12      2017-03-18 11:27:24
    • Expected output data
dollars timestampGMT          rolling_median_dollar
25      2017-03-18 11:27:18   median(25)
17      2017-03-18 11:27:19   median(17,25)
13      2017-03-18 11:27:20   median(13,17,25)
27      2017-03-18 11:27:21   median(27,13,17)
13      2017-03-18 11:27:22   median(13,27,13)
43      2017-03-18 11:27:23   median(43,13,27)
12      2017-03-18 11:27:24   median(12,43,13)

The code below works for a rolling average, but PySpark does not have an F.median():
pyspark: rolling average using timeseries data
Edit 1: the challenge is that a median() function does not exist. I cannot run

df = df.withColumn('rolling_average', F.median("dollars").over(w))

the way I would for a rolling average:

df = df.withColumn('rolling_average', F.avg("dollars").over(w))

Edit 2: I tried using approxQuantile():

windfun = Window().partitionBy().orderBy(F.col(date_column)).rowsBetween(-3, 0)
sdf.withColumn("movingMedian", sdf.approxQuantile(col='a', probabilities=[0.5], relativeError=0.00001).over(windfun))

but that fails with an error:

AttributeError: 'list' object has no attribute 'over'
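
The error happens because approxQuantile is a DataFrame method: it executes immediately and returns a plain Python list, not a Column, so there is nothing to call .over() on. A window-friendly relative is F.percentile_approx from pyspark.sql.functions (available since Spark 3.1); the sketch below is an assumed way of wiring it into the same kind of window, not code from the original post:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# sketch: percentile_approx is an aggregate expression, so unlike
# df.approxQuantile it can be evaluated over a window frame
w = Window.orderBy(F.col("timestampGMT").cast("long")).rowsBetween(-2, 0)

df_with_median = df.withColumn(
    "rolling_median_dollar",
    F.percentile_approx("dollars", 0.5).over(w)   # 0.5 quantile of the current and 2 preceding rows
)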

Edit 3
Please give a solution that does not use a UDF, since UDFs do not benefit from Catalyst optimization.


t2a7ltrp1#

One way is to collect the dollars column into a list for each window and then apply a UDF:

from pyspark.sql.window import Window
from pyspark.sql.functions import *
import numpy as np
from pyspark.sql.types import FloatType

# window frame: the current row plus the preceding 2 seconds (timestamps cast to epoch seconds)
w = (Window.orderBy(col("timestampGMT").cast('long')).rangeBetween(-2, 0))
# UDF that takes the collected list of values and returns their median via numpy
median_udf = udf(lambda x: float(np.median(x)), FloatType())

# collect the dollars in each window into a list, then apply the median UDF to that list
df.withColumn("list", collect_list("dollars").over(w)) \
  .withColumn("rolling_median", median_udf("list")).show(truncate = False)
+-------+---------------------+------------+--------------+
|dollars|timestampGMT         |list        |rolling_median|
+-------+---------------------+------------+--------------+
|25     |2017-03-18 11:27:18.0|[25]        |25.0          |
|17     |2017-03-18 11:27:19.0|[25, 17]    |21.0          |
|13     |2017-03-18 11:27:20.0|[25, 17, 13]|17.0          |
|27     |2017-03-18 11:27:21.0|[17, 13, 27]|17.0          |
|13     |2017-03-18 11:27:22.0|[13, 27, 13]|13.0          |
|43     |2017-03-18 11:27:23.0|[27, 13, 43]|27.0          |
|12     |2017-03-18 11:27:24.0|[13, 43, 12]|13.0          |
+-------+---------------------+------------+--------------+
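
Note that rangeBetween(-2, 0) over the timestamp cast to long frames each window as "the current row plus anything in the previous 2 seconds", which matches a 3-row window here only because the sample timestamps are exactly one second apart. If the requirement is strictly the last 3 rows regardless of spacing, a row-based frame is safer; a minimal variation of the same approach:

# frame over the current row and the 2 preceding rows, independent of timestamp gaps
w_rows = Window.orderBy(col("timestampGMT")).rowsBetween(-2, 0)

df.withColumn("list", collect_list("dollars").over(w_rows)) \
  .withColumn("rolling_median", median_udf("list")).show(truncate=False)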

k5hmc34c2#

Another way, without using any UDF, is to use expr from pyspark.sql.functions:

from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

data = [{'dollars': 25, 'timestampGMT': '2017-03-18 11:27:18'},
        {'dollars': 17, 'timestampGMT': '2017-03-18 11:27:19'},
        {'dollars': 13, 'timestampGMT': '2017-03-18 11:27:20'},
        {'dollars': 27, 'timestampGMT': '2017-03-18 11:27:21'},
        {'dollars': 13, 'timestampGMT': '2017-03-18 11:27:22'},
        {'dollars': 43, 'timestampGMT': '2017-03-18 11:27:23'},
        {'dollars': 12, 'timestampGMT': '2017-03-18 11:27:24'}
       ]

test = spark.createDataFrame(data, schema=['dollars', 'timestampGMT'])

# percentile(dollars, 0.5) is the exact median, computed over the current row and
# the 2 preceding rows; the constant id column gives the window a single partition
test.withColumn("id", F.lit(1)).withColumn(
    "rolling_median_dollar",
    F.expr("percentile(dollars,0.5)").over(
        W.partitionBy("id")
        .orderBy(F.col("timestampGMT").cast("long"))
        .rowsBetween(-2, 0)
    ),
).drop('id').show()

+-------+-------------------+---------------------+
|dollars|       timestampGMT|rolling_median_dollar|
+-------+-------------------+---------------------+
|     25|2017-03-18 11:27:18|                 25.0|
|     17|2017-03-18 11:27:19|                 21.0|
|     13|2017-03-18 11:27:20|                 17.0|
|     27|2017-03-18 11:27:21|                 17.0|
|     13|2017-03-18 11:27:22|                 13.0|
|     43|2017-03-18 11:27:23|                 27.0|
|     12|2017-03-18 11:27:24|                 13.0|
+-------+-------------------+---------------------+
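
Since Spark 3.4, pyspark.sql.functions also includes a median function, so on newer versions the expr string should not be needed. Assuming a Spark 3.4+ cluster (and that median behaves like the other aggregates when used with .over()), a sketch of the same query would be:

from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

# assumes Spark 3.4+, where F.median exists; earlier versions need percentile/expr
test.withColumn("id", F.lit(1)).withColumn(
    "rolling_median_dollar",
    F.median("dollars").over(
        W.partitionBy("id")
        .orderBy(F.col("timestampGMT").cast("long"))
        .rowsBetween(-2, 0)
    ),
).drop("id").show()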
