Rewriting a Pandas UDF as pure PySpark

Posted by qvtsj1bj on 2023-10-15 in Spark

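I have the code below and want to rewrite the Pandas UDF as a pure PySpark window function to make it faster.
cumulative_pass is the column I want to create programmatically: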

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql import Window
import sys 

spark_session = SparkSession.builder.getOrCreate()
 

df_data = {'username': ['bob','bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob'],
           'session': [1,2,3,4,5,6,7,8],
           'year_start': [2020,2020,2020,2020,2020,2021,2022,2023],
           'year_end': [2020,2020,2020,2020,2021,2021,2022,2023],
           'pass': [1,0,0,0,0,1,1,0],
           'cumulative_pass': [0,0,0,0,0,1,2,3],
          }
df_pandas = pd.DataFrame.from_dict(df_data)

df = spark_session.createDataFrame(df_pandas)
df.show()

The final show() prints:

+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
|     bob|      1|      2020|    2020|   1|              0|
|     bob|      2|      2020|    2020|   0|              0|
|     bob|      3|      2020|    2020|   0|              0|
|     bob|      4|      2020|    2020|   0|              0|
|     bob|      5|      2020|    2021|   0|              0|
|     bob|      6|      2021|    2021|   1|              1|
|     bob|      7|      2022|    2022|   1|              2|
|     bob|      8|      2023|    2023|   0|              3|
+--------+-------+----------+--------+----+---------------+

The code below works, but it is slow (UDFs are slow):

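from pyspark.sql.types import IntegerType

def conditional_sum(data: pd.DataFrame) -> int:
    df = data.apply(pd.Series)
    # Sum `pass` over the rows that ended before the latest year_start in the frame
    return df.loc[df['year_start'].max() > df['year_end']]['pass'].sum()

udf_conditional_sum = F.pandas_udf(conditional_sum, IntegerType())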

w = Window.partitionBy("username").orderBy(F.asc("year_start")).rowsBetween(-sys.maxsize, 0)
df = df.withColumn("calculate_cumulative_pass", udf_conditional_sum(F.struct("year_start", "year_end", "pass")).over(w))

Note: I modified w slightly and removed the second sort key.
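To make the target explicit, here is a minimal plain-Python sketch of what the windowed UDF appears to compute for each row. It assumes the rows are already sorted by year_start and that the frame runs from the first row to the current row. The names `rows`, `conditional_sum_reference`, and `reference` are illustrative and not part of the original code.

```python
# Plain-Python reference for the windowed UDF's logic (no Spark needed).
rows = [
    # (session, year_start, year_end, pass)
    (1, 2020, 2020, 1),
    (2, 2020, 2020, 0),
    (3, 2020, 2020, 0),
    (4, 2020, 2020, 0),
    (5, 2020, 2021, 0),
    (6, 2021, 2021, 1),
    (7, 2022, 2022, 1),
    (8, 2023, 2023, 0),
]


def conditional_sum_reference(frame):
    """Sum `pass` over frame rows that ended before the latest year_start."""
    latest_start = max(year_start for _, year_start, _, _ in frame)
    return sum(p for _, _, year_end, p in frame if year_end < latest_start)


# Assumption: rows are sorted by year_start; frame = first row .. current row.
reference = [conditional_sum_reference(rows[: i + 1]) for i in range(len(rows))]
print(reference)  # [0, 0, 0, 0, 0, 1, 2, 3]
```

The printed list matches the cumulative_pass column in the sample data, so any pure-window rewrite can be checked against it.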

Answer #1 (vom3gejh)

Code

W = Window.partitionBy('username').orderBy('year_start')
df = (
    df
    .withColumn('cumulative_pass',  F.collect_list(F.struct('year_end', 'pass')).over(W))
    .withColumn('cumulative_pass',  F.expr("AGGREGATE(cumulative_pass, 0, (acc, x) -> CAST(acc + IF(x['year_end'] < year_start, x['pass'], 0) AS INT))"))
)

How it works

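Create a window specification and collect the (year_end, pass) pairs of the rows up to the current one. Then aggregate the pairs, summing the pass value of every pair whose year_end is less than the current row's year_start. Because the window has an orderBy but no explicit frame, Spark uses the default range frame, so the collected list also contains the current row and any rows that share its year_start. In the sample data those rows never satisfy year_end < year_start, so they contribute nothing.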
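The collect-then-fold idea can be sketched in plain Python, without Spark. This is only an illustration under stated assumptions: the list comprehension imitates Spark's default frame for an ordered window (every row whose year_start is less than or equal to the current one, peers included), and `functools.reduce` stands in for the SQL AGGREGATE call. The names `rows` and `cumulative_pass` are hypothetical.

```python
from functools import reduce

# (year_start, year_end, pass) for the sample user, sorted by year_start.
rows = [
    (2020, 2020, 1),
    (2020, 2020, 0),
    (2020, 2020, 0),
    (2020, 2020, 0),
    (2020, 2021, 0),
    (2021, 2021, 1),
    (2022, 2022, 1),
    (2023, 2023, 0),
]


def cumulative_pass(rows):
    """Imitate collect_list over the default range frame, then a fold."""
    out = []
    for year_start, _, _ in rows:
        # Default frame with orderBy: all rows with year_start <= current (peers included).
        collected = [(ye, p) for ys, ye, p in rows if ys <= year_start]
        # Fold: add pass only when the pair's year_end is before the current year_start.
        total = reduce(
            lambda acc, x: acc + (x[1] if x[0] < year_start else 0),
            collected,
            0,
        )
        out.append(total)
    return out


result = cumulative_pass(rows)
print(result)  # [0, 0, 0, 0, 0, 1, 2, 3]
```

Note that each row's list grows with the number of rows before it, so the work per partition grows roughly quadratically. That is worth keeping in mind for users with many sessions.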

Result
+--------+-------+----------+--------+----+---------------+
|username|session|year_start|year_end|pass|cumulative_pass|
+--------+-------+----------+--------+----+---------------+
|bob     |1      |2020      |2020    |1   |0              |
|bob     |2      |2020      |2020    |0   |0              |
|bob     |3      |2020      |2020    |0   |0              |
|bob     |4      |2020      |2020    |0   |0              |
|bob     |5      |2020      |2021    |0   |0              |
|bob     |6      |2021      |2021    |1   |1              |
|bob     |7      |2022      |2022    |1   |2              |
|bob     |8      |2023      |2023    |0   |3              |
+--------+-------+----------+--------+----+---------------+

Answer #2 (g2ieeal7)

This approach is similar to @Shubham's. It uses F.transform to zero out the entries that do not qualify, then F.aggregate to sum what remains.
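The two steps can be traced by hand for a single row. The sketch below is plain Python rather than Spark: a list comprehension plays the role of F.transform and `functools.reduce` plays the role of F.aggregate. The helper name `trace_row` is hypothetical. The input pairs are the (year_end, pass_new) values from the rows that precede sessions 6 and 8 in the sample data.

```python
from functools import reduce


def trace_row(collect_values, year_start):
    """Mimic the transform step followed by the aggregate step for one row."""
    # Step 1 (transform): keep pass_new when year_end < year_start, else 0.
    filtered_pass = [p if year_end < year_start else 0 for year_end, p in collect_values]
    # Step 2 (aggregate): fold the list into a sum, starting from 0.
    total = reduce(lambda acc, x: acc + x, filtered_pass, 0)
    return filtered_pass, total


# (year_end, pass_new) pairs from the rows before session 6 and before session 8.
before_session_6 = [(2020, 1), (2020, 0), (2020, 0), (2020, 0), (2021, 0)]
before_session_8 = before_session_6 + [(2021, 1), (2022, 1)]

print(trace_row(before_session_6, 2021))  # ([1, 0, 0, 0, 0], 1)
print(trace_row(before_session_8, 2023))  # ([1, 0, 0, 0, 0, 1, 1], 3)
```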

import sys
import pandas as pd

from pyspark import SparkContext, SQLContext
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.window import Window

spark = SparkSession.builder \
    .appName("MyApp") \
    .getOrCreate()

sc = spark.sparkContext
sqlContext = SQLContext(sc)


df_data = {'username': ['bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob', 'bob'],
           'session': [1, 2, 3, 4, 5, 6, 7, 8],
           'year_start': [2020, 2020, 2020, 2020, 2020, 2021, 2022, 2023],
           'year_end': [2020, 2020, 2020, 2020, 2021, 2021, 2022, 2023],
           'pass_new': [1, 0, 0, 0, 0, 1, 1, 0],
           'cumulative_pass_given': [0, 0, 0, 0, 0, 1, 2, 3],
           }
df_pandas = pd.DataFrame.from_dict(df_data)

df = sqlContext.createDataFrame(df_pandas)
print("Given dataframe")
df.show()

window_spec = Window.partitionBy('username').orderBy(F.col("year_start"), F.col("year_end")).rowsBetween(Window.unboundedPreceding, -1)

spark_df = df.withColumn("collect_values", F.collect_list(F.struct('year_end', 'pass_new')).over(window_spec))

print("grouped sum dataframe")
spark_df.show(n=100, truncate=False)

df_filtered = spark_df.withColumn(
    "filtered_pass",
    F.transform(
        F.col("collect_values"),
        # Keep pass_new only for entries that ended before this row's year_start
        lambda x: F.when(x.year_end < F.col("year_start"), x.pass_new.cast("int")).otherwise(0),
    ),
).cache()

print(" df_filtered dataframe")
df_filtered.show(n=100, truncate=False)

df_calculated_agg = df_filtered.withColumn("required_pass_using_aggregate", F.aggregate(F.col("filtered_pass"), F.lit(0), lambda acc, x: acc + x))
print("df_calculated_agg")
df_calculated_agg.show(n=100, truncate=False)

Output:

Given dataframe
+--------+-------+----------+--------+--------+---------------------+
|username|session|year_start|year_end|pass_new|cumulative_pass_given|
+--------+-------+----------+--------+--------+---------------------+
|     bob|      1|      2020|    2020|       1|                    0|
|     bob|      2|      2020|    2020|       0|                    0|
|     bob|      3|      2020|    2020|       0|                    0|
|     bob|      4|      2020|    2020|       0|                    0|
|     bob|      5|      2020|    2021|       0|                    0|
|     bob|      6|      2021|    2021|       1|                    1|
|     bob|      7|      2022|    2022|       1|                    2|
|     bob|      8|      2023|    2023|       0|                    3|
+--------+-------+----------+--------+--------+---------------------+

grouped sum dataframe
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+
|username|session|year_start|year_end|pass_new|cumulative_pass_given|collect_values                                                               |
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+
|bob     |1      |2020      |2020    |1       |0                    |[]                                                                           |
|bob     |2      |2020      |2020    |0       |0                    |[{2020, 1}]                                                                  |
|bob     |3      |2020      |2020    |0       |0                    |[{2020, 1}, {2020, 0}]                                                       |
|bob     |4      |2020      |2020    |0       |0                    |[{2020, 1}, {2020, 0}, {2020, 0}]                                            |
|bob     |5      |2020      |2021    |0       |0                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}]                                 |
|bob     |6      |2021      |2021    |1       |1                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}]                      |
|bob     |7      |2022      |2022    |1       |2                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}, {2021, 1}]           |
|bob     |8      |2023      |2023    |0       |3                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}, {2021, 1}, {2022, 1}]|
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+

 df_filtered dataframe
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+---------------------+
|username|session|year_start|year_end|pass_new|cumulative_pass_given|collect_values                                                               |filtered_pass        |
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+---------------------+
|bob     |1      |2020      |2020    |1       |0                    |[]                                                                           |[]                   |
|bob     |2      |2020      |2020    |0       |0                    |[{2020, 1}]                                                                  |[0]                  |
|bob     |3      |2020      |2020    |0       |0                    |[{2020, 1}, {2020, 0}]                                                       |[0, 0]               |
|bob     |4      |2020      |2020    |0       |0                    |[{2020, 1}, {2020, 0}, {2020, 0}]                                            |[0, 0, 0]            |
|bob     |5      |2020      |2021    |0       |0                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}]                                 |[0, 0, 0, 0]         |
|bob     |6      |2021      |2021    |1       |1                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}]                      |[1, 0, 0, 0, 0]      |
|bob     |7      |2022      |2022    |1       |2                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}, {2021, 1}]           |[1, 0, 0, 0, 0, 1]   |
|bob     |8      |2023      |2023    |0       |3                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}, {2021, 1}, {2022, 1}]|[1, 0, 0, 0, 0, 1, 1]|
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+---------------------+

df_calculated_agg
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+---------------------+-----------------------------+
|username|session|year_start|year_end|pass_new|cumulative_pass_given|collect_values                                                               |filtered_pass        |required_pass_using_aggregate|
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+---------------------+-----------------------------+
|bob     |1      |2020      |2020    |1       |0                    |[]                                                                           |[]                   |0                            |
|bob     |2      |2020      |2020    |0       |0                    |[{2020, 1}]                                                                  |[0]                  |0                            |
|bob     |3      |2020      |2020    |0       |0                    |[{2020, 1}, {2020, 0}]                                                       |[0, 0]               |0                            |
|bob     |4      |2020      |2020    |0       |0                    |[{2020, 1}, {2020, 0}, {2020, 0}]                                            |[0, 0, 0]            |0                            |
|bob     |5      |2020      |2021    |0       |0                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}]                                 |[0, 0, 0, 0]         |0                            |
|bob     |6      |2021      |2021    |1       |1                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}]                      |[1, 0, 0, 0, 0]      |1                            |
|bob     |7      |2022      |2022    |1       |2                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}, {2021, 1}]           |[1, 0, 0, 0, 0, 1]   |2                            |
|bob     |8      |2023      |2023    |0       |3                    |[{2020, 1}, {2020, 0}, {2020, 0}, {2020, 0}, {2021, 0}, {2021, 1}, {2022, 1}]|[1, 0, 0, 0, 0, 1, 1]|3                            |
+--------+-------+----------+--------+--------+---------------------+-----------------------------------------------------------------------------+---------------------+-----------------------------+
