在PySpark Dataframe 中优化筛选器+更新联接循环

h5qlskok  于 2022-09-21  发布在  Spark
关注(0)|答案(1)|浏览(152)

我们有一个PySpark Dataframe ,它表示一所有35名学生和3个班级的学校,每行代表一个学生。1班有20名学生,2班有10名学生,3班有5名学生。

我们想要比较这三个班级,因此我们将分配一个考试给每个班级中最少4名学生和最多班级人数的一半。这些学生必须随机挑选。

如何才能尽可能高效地执行这项任务?我在这里分享我到目前为止构建的代码。

students = students.withColumn("exam", sf.lit(None))
for class_ in students.select("classroom").distinct().collect():

    for group_ in ['TEST', 'NO_TEST']:
        subdf = students.filter(sf.col("classroom") == class_[0])

            if group_ == 'TEST':
                subdf_group = subdf.sample(False, 0.5).limit(4) 
                    .withColumn("exam", sf.lit("EXAM"))

             else:
                 subdf_group = subdf.filter(sf.isnull(sf.col("exam"))) 
                    .withColumn("exam", sf.lit("NO_EXAM"))

         students = self.update_df(students, subdf_group)

def update_df(self, df_, new_df_):
    """
    A left join that updates the values from the students df with the new
    values on exam column.
    """
    out_df = df_.alias('l') 
        .join(new_df_.select("student_id", "exam").alias('r'),
              on="student_id", how="left").select(
        "student_id",
        self.update_column("exam")
    )
    return out_df

def update_column(column_name: str, left: str ='l', right: str ='r'):
    """
    When joining two dfs with same column names, we keep the column values from
    the right dataframe when values on right are not null, else, we keep the
    values on the left column.
    """
    return sf.when(~sf.isnull(sf.col(f'{right}.{column_name}')),
                   sf.col(f'{right}.{column_name}')) 
        .otherwise(sf.col(f'{left}.{column_name}')).alias(column_name)

这是一个玩具的例子。实际上,我们在 Dataframe 中有135个类和总共400万行,使用我在上面分享的代码,任务运行得很差。

dgsult0t

dgsult0t1#

你分享的代码本质上非常平淡无奇。
最佳性能是在不使用Python循环和多重洗牌时。在共享的代码中,python循环使用数据集X中的类数创建一个Spark计划,并进行筛选器和联接(Join=繁重的随机操作),因此性能较差。

假设您有一个包含两列["class", "student"]的 Dataframe ,我将使用窗口函数来完成此操作。
Spark会将每个窗口函数分区(在我们的例子中是类)发送到不同的执行器,因此您将并行采样,而不需要每次都过滤大的DF。

from pyspark.sql.functions import col, row_number, rand, count
from pyspark.sql.window import Window

df 
.select(
    "*",
    row_number().over(Window.partitionBy('class').orderBy(rand(123))).alias('random_position'),
    count("*").over(Window.partitionBy('class')).alias('num_students_in_class'),
    ) 
.withColumn(
    "takes_test",
    (col("random_position") <= 4) &
     ((col("num_students_in_class") / 2) > col("random_position")))

相关问题