Aggregating the mode (most common element) in a Spark DataFrame

toiithl6 posted on 2021-05-27 in Spark

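In Spark, I'm using a library that expects me to supply an aggregation; the library then performs a series of joins/groupBys and applies the aggregation at the end. I'd like to avoid breaking encapsulation (although I can if necessary) and simply call this method with an aggregation (traditionally something like sum or min).
In this case I'm trying to compute the mode, but I don't know how to express it as an aggregation.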


zpgglvta1#

Here is a Spark (2.1.0) UDAF that computes the statistical mode of a given column:

package org.anish.spark.mostcommonvalue

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scalaz.Scalaz._

/**
  * Spark User Defined Aggregate Function to calculate the most frequent value in a column. This is similar to the
  * statistical mode. When two or more values tie for the highest frequency, this function returns an arbitrary one
  * of them, whereas the statistical mode would treat all of the tied values together as the mode.
  *
  * Usage:
  *
  * DataFrame / DataSet DSL
  * val mostCommonValue = new MostCommonValue
  * df.groupBy("group_id").agg(mostCommonValue(col("mode_column")), mostCommonValue(col("city")))
  *
  * Spark SQL:
  * sqlContext.udf.register("mode", new MostCommonValue)
  * %sql
  * -- Use a group_by statement and call the UDAF.
  * select group_id, mode(id) from table group by group_id
  * 
  * Reference: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html
  *
  * Created by anish on 26/05/17.
  */
class MostCommonValue extends UserDefinedAggregateFunction {

  // Input fields of the aggregate function.
  // StringType is used because the mode is also meaningful for nominal (categorical) data.
  override def inputSchema: StructType =
    StructType(StructField("value", StringType) :: Nil)

  // Internal fields kept while computing the aggregate.
  // The buffer holds a map from each distinct value seen so far to its frequency.
  override def bufferSchema: StructType = StructType(
    StructField("frequencyMap", DataTypes.createMapType(StringType, LongType)) :: Nil
  )

  // This is the output type of your aggregation function.
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  // This is the initial value for the buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
  }

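  // Updates the buffer with one input row by incrementing the count of the input value.
  // Null inputs are skipped (like built-in aggregates such as count(col)); Spark map keys cannot be null.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getAs[Map[String, Long]](0) |+| Map(input.getAs[String](0) -> 1L)
    }
  }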

  // Merges two partial buffers (e.g. from different partitions); scalaz's |+| adds the counts key by key.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Map[String, Long]](0) |+| buffer2.getAs[Map[String, Long]](0)
  }

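  // Produces the final result from the merged buffer: the value with the highest count.
  // Returns null when no non-null value was seen, because maxBy throws on an empty map.
  override def evaluate(buffer: Row): String = {
    val frequencies = buffer.getAs[Map[String, Long]](0)
    if (frequencies.isEmpty) null else frequencies.maxBy(_._2)._1
  }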
}
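Why mode can go in the same slot as sum or min: Spark drives an aggregation through initialize/update/merge/evaluate, so any aggregate works as long as its buffer can be updated row by row and merged across partitions. For sum the buffer is a single number; for mode it has to be a count per distinct value, as in the UDAF above. The following is a Spark-free sketch only; `Agg`, `SumAgg`, `ModeAgg` and `Simulate.runPartitioned` are hypothetical names that model the contract and are not Spark API.

```scala
// Hypothetical, Spark-free model of the initialize/update/merge/evaluate contract.
trait Agg[In, Buf, Out] {
  def zero: Buf                       // initialize
  def update(buf: Buf, in: In): Buf   // fold one input row into the buffer
  def merge(a: Buf, b: Buf): Buf      // combine partial buffers from two partitions
  def evaluate(buf: Buf): Out         // final result
}

// sum fits easily: the buffer is one running number.
object SumAgg extends Agg[Long, Long, Long] {
  def zero: Long = 0L
  def update(buf: Long, in: Long): Long = buf + in
  def merge(a: Long, b: Long): Long = a + b
  def evaluate(buf: Long): Long = buf
}

// mode needs a richer buffer: a count for every distinct value.
object ModeAgg extends Agg[String, Map[String, Long], Option[String]] {
  def zero: Map[String, Long] = Map.empty
  def update(buf: Map[String, Long], in: String): Map[String, Long] =
    buf.updated(in, buf.getOrElse(in, 0L) + 1L)
  // Add counts key by key (the behavior scalaz's |+| gives Map[String, Long]).
  def merge(a: Map[String, Long], b: Map[String, Long]): Map[String, Long] =
    b.foldLeft(a) { case (acc, (k, n)) => acc.updated(k, acc.getOrElse(k, 0L) + n) }
  def evaluate(buf: Map[String, Long]): Option[String] =
    if (buf.isEmpty) None else Some(buf.maxBy(_._2)._1)
}

object Simulate {
  // Mimic Spark: aggregate each partition separately, then merge the partial buffers.
  def runPartitioned[In, Buf, Out](agg: Agg[In, Buf, Out], partitions: Seq[Seq[In]]): Out = {
    val partials = partitions.map(_.foldLeft(agg.zero)(agg.update))
    agg.evaluate(partials.foldLeft(agg.zero)(agg.merge))
  }
}
```

For example, `Simulate.runPartitioned(ModeAgg, Seq(Seq("a", "b"), Seq("b", "c"), Seq("b", "a")))` merges the per-partition counts into `a -> 2, b -> 3, c -> 1` and yields `Some("b")`. Writing the merge by hand like this is one way to drop the scalaz dependency from the UDAF.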

Credit/source: https://gist.github.com/anish749/6a815ed281f538068a0d3a20ca9044fa
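On ties, `maxBy(_._2)._1` in `evaluate` returns whichever tied value the map's iteration order happens to produce, as the doc comment warns. If a reproducible result is needed, one option (an assumption, not part of the original gist) is to break ties by value: highest count first, then the lexicographically smallest string. `ModeTieBreak.deterministicMode` is a hypothetical helper whose logic could replace the body of `evaluate`.

```scala
object ModeTieBreak {
  // Hypothetical deterministic rule: highest count wins; among equal counts,
  // the lexicographically smallest value wins. Empty input yields None.
  def deterministicMode(freq: Map[String, Long]): Option[String] =
    if (freq.isEmpty) None
    else Some(freq.minBy { case (value, count) => (-count, value) }._1)
}
```

Inside the UDAF this would become something like `if (frequencies.isEmpty) null else ModeTieBreak.deterministicMode(frequencies).orNull`.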
