pyspark python spark固定行数的简单移动平均

deyfvvtc  于 2023-08-02  发布在  Spark
关注(0)|答案(2)|浏览(125)

我正在尝试在pyspark中实现SMA(简单移动平均线)。我面临的问题是我需要3行的SMA,但spark(在这个简单的例子中)给了我SMA 1,SMA 2,SMA 3,SMA 3。

import pyspark # 3.4.1
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        (1, 5),
        (2, 10),
        (3, 15),
        (4, 20)
    ],
    ('id1', 'v1')
)
df.createOrReplaceTempView("x")
spark.sql('select avg(v1) over (order by id1 rows between 2 preceding and current row) as v1 from x').collect()
#[Row(v1=5.0), Row(v1=7.5), Row(v1=10.0), Row(v1=15.0)]

字符串
当SMA不是SMA 3时,我怎么能有NULL?
换句话说,我的预期结果是v1 = NULL, NULL, 10.0, 15.0

ezykj2lf

ezykj2lf1#

你可以使用case语句来实现它

spark.sql('''
    SELECT 
        CASE 
            WHEN ROW_NUMBER() OVER (ORDER BY id1) <= 2 THEN NULL
            ELSE AVG(v1) OVER (ORDER BY id1 ROWS BETWEEN 2 PRECEDING AND CURRENT ROW)
        END AS v1
    FROM x
''').collect()

字符串
输出量:

[Row(v1=None), Row(v1=None), Row(v1=10.0), Row(v1=15.0)]

qv7cva1a

qv7cva1a2#

我建议使用LAG的SQL的其他方法:

spark.sql('''
SELECT (
       v1 +
       LAG(v1) OVER (order by id1) +
       LAG(v1, 2) OVER (order by id1)
       ) / 3 AS avg_v1
FROM x
''').collect()

字符串
然而,延长SMA窗口可能是令人头痛的。

相关问题