我有一个包含许多不同级别的字符串变量的数据集,还有一个需要求平均值的数值列。在得到平均值后,我需要按降序选择每个列的前N个级别,其中目标列的平均值最高。最后,对于每一列,我需要检查列中的值是否在前N个列表中,如果是,则保留它们,如果不是,则用字符串“Other”替换不在列表中的级别。结构数据看起来像这样,但有更多的列和许多行:
| 列A|色谱柱B|列C|目标|
| --|--|--|--|
| A1| B1| C1| 25 |
| A1| B2| C2| 50 |
| A2| B3| C2| 10 |
| A2| B3| C2| 15 |
现在,我正在运行一个for循环,按每列分组,并取Target列的平均值,一次一列:
import pyspark.sql.functions as F
cols_to_group = ['A','B','C']
top_n = 10
for c in cols_to_group:
agg_df = (
df
.groupBy(c)
.agg({"Target":"avg"})
.orderBy(F.col("avg(Target)").desc())
.select(c)
.limit(top_n)
)
levels = [lvl[0] for lvl in agg_df.collect()]
df = (
df
.withColumn(c, F.when(F.col(c).isin(levels), F.col(c)).otherwise(F.lit("Other"))
)
这是可行的,但它是非常缓慢的,这个问题似乎应该并行化相当容易。有没有一种方法可以让我并行地为每一列运行groupBy/aggregation,然后检查该列是否在列表中,如果是,则保留它,否则填写字符串“Other”?
1条答案
按热度按时间icnyk63a1#
如果每列的非重复值的数量不是太大,你可以分解这个矩阵,使它具有
(column, value, Target)
的形式,计算每个(column, value)
的均值,然后按column
聚合,并收集均值列表。最后你把列表切片,收集它并转换原始的df。