如何获得Pypark中的平均值?

wbrvyc0a  于 2021-07-09  发布在  Spark
关注(0)|答案(1)|浏览(416)

我有一个sparkDataframe:

df = spark.createDataFrame([(10, "Hyundai"), (20, "alpha") ,(70,'Audio'), (1000,'benz'), (50,'Suzuki'),(60,'Lambo'),(30,'Bmw')],["Cars", "Brand"])

现在我想找出异常值,因为我使用了iqr,得到了上下两个值,如下图所示,然后找到了异常值:

lower, upper = -55.0 145.0
outliers= df.filter((df['Cars'] > upper) | (df['Cars'] < lower))
Cars    Brand
1000    benz

现在我想找出排除异常值的平均值,找出我用过的函数,什么时候,我得到了这样的错误

"TypeError: 'Column' object is not callable"

from pyspark.sql import functions as fun
mean = df.select(fun.when((df['Cars'] > upper) | (df['Cars'] < lower), fun.mean(df['Cars'].alias('mean')).collect()[0]['mean']))
print(mean)

是我的代码错了还是有更好的方法?

7kjnsjlb

7kjnsjlb1#

我想你不需要使用 when . 你只需做一个筛选,然后汇总平均值:

import pyspark.sql.functions as F

mean = df.filter((df['Cars'] <= upper) & (df['Cars'] >= lower)).agg(F.mean('cars').alias('mean'))

mean.show()
+----+
|mean|
+----+
|40.0|
+----+

如果你想用 when ,可以使用条件聚合:

mean = df.agg(F.mean(F.when((df['Cars'] <= upper) & (df['Cars'] >= lower), df['Cars'])).alias('mean'))

mean.show()
+----+
|mean|
+----+
|40.0|
+----+

要收集到变量,可以使用collect:

mean_collected = mean.collect()[0][0]

相关问题