我有下面的代码,我想在pyspark中将Pandas UDF重写为纯窗口函数,以优化速度
列cumulative_pass
是我想以编程方式创建的-
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql import Window
import sys
spark_session = SparkSession.builder.getOrCreate()
df_data = {'username': ['bob','bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob'],
'session': [1,2,3,4,5,6,7,8],
'year_start': [2020,2020,2020,2020,2020,2021,2022,2023],
'year_end': [2020,2020,2020,2020,2021,2021,2022,2023],
'pass': [1,0,0,0,0,1,1,0],
'cumulative_pass': [0,0,0,0,0,1,2,3],
}
df_pandas = pd.DataFrame.from_dict(df_data)
df = spark_session.createDataFrame(df_pandas)
df.show()
最后一个show
是这样的
+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
| bob| 1| 2020| 2020| 1| 0|
| bob| 2| 2020| 2020| 0| 0|
| bob| 3| 2020| 2020| 0| 0|
| bob| 4| 2020| 2020| 0| 0|
| bob| 5| 2020| 2021| 0| 0|
| bob| 6| 2021| 2021| 1| 1|
| bob| 7| 2022| 2022| 1| 2|
| bob| 8| 2023| 2023| 0| 3|
+--------+-------+----------+--------+----+---------------+
下面的代码可以工作,但速度很慢(UDF很慢)
def conditional_sum(data: pd.DataFrame) -> int:
df = data.apply(pd.Series)
return df.loc[df['year_start'].max() > df['year_end']]['pass'].sum()
udf_conditional_sum = F.pandas_udf(conditional_sum, IntegerType())
w = Window.partitionBy("username").orderBy(F.asc("year_start")).rowsBetween(-sys.maxsize, 0)
df = df.withColumn("calculate_cumulative_pass", udf_conditional_sum(F.struct("year_start", "year_end", "pass")).over(w))
注意-我稍微修改了w
,并删除了第二个排序
2条答案
按热度按时间vom3gejh1#
代码
如何工作
创建一个窗口规范,并收集前面所有行的
year_end
和pass
值对。当对中的year_end
小于当前行的year_start
时,聚合对和sum
对中的pass
值。结果
g2ieeal72#
这种方法类似于@Shubham。使用
F.transform
进行过滤,然后使用F.aggregate
进行聚合。输出量: