scala 如何获得一个有条件的最大从martrame

oxf4rvwz  于 12个月前  发布在  Scala
关注(0)|答案(2)|浏览(152)

我想根据特定的条件从一个数组中获取最大值,请使用spark scala。
我们有列“a”和列“B”,我想得到每列的最大值,但条件是:if column“a”contains 1,then 1,otherwise max(“a”)max(“B”)有一个重复:

df :                      res:
+---+---+                 +------+------+
|  a|  b|                 |  maxA|  maxB|
+---+---+                 +------+------+
|  1|  4|                 |     1|     5|
|  4|  3|                 +------+------+
|  6|  5|
|  8|  1|
+---+---+

    df :                   res:
+---+---+                 +------+------+
|  a|  b|                 |  maxA|  maxB|
+---+---+                 +------+------+
|  0|  4|                 |     8|     5|
|  4|  3|                 +------+------+
|  6|  5|
|  8|  1|
+---+---+

我试过这个,但没有工作:

val res = df.agg(max(
            when(col("a").contains(1),lit(1))
            .otherwise(col("a")))
            .as("maxA")
            ,max("b").as("maxB")
            )

我需要把我的条件写在AGG MAX函数里面,而不是外面。

kse8i1jr

kse8i1jr1#

为什么你的方法行不通

让我们看看when(col("a").contains(1), lit(.otherwise(col("a"))的输出:

df.select(col("a"), when(col("a").contains(1), lit(1)).otherwise(col("a"))).show()
+---+------------------------------------------+
|  a|CASE WHEN contains(a, 1) THEN 1 ELSE a END|
+---+------------------------------------------+
|  1|                                         1|
|  2|                                         2|
|  4|                                         4|
|  9|                                         9|
+---+------------------------------------------+

如您所见,输出列是相同的。所以基本上,它什么都不做。
如何做你想做的事
如果你只想在一个.agg中做这件事,那么使用默认的spark聚合函数是不可能的。好消息是,我们可以实现自定义聚合函数(查看this):

import scala.math

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions
import spark.implicits._

val df = spark.createDataset(Seq(
  (1, 6),
  (2, 0),
  (4, 3),
  (9, 1))
).select(col("_1").as("a"), col("_2").as("b"))

case class Buffer(var max: Int, var oneFound: Boolean)

object MaxIfNo1Found extends Aggregator[Int, Buffer, Int] {

    def zero: Buffer = Buffer(Int.MinValue, false)

    def reduce(buffer: Buffer, value: Int): Buffer = {
        buffer.max = math.max(buffer.max, value) 
        buffer.oneFound = buffer.oneFound | value == 1
        buffer
    }
    
    def merge(b1: Buffer, b2: Buffer): Buffer = {
        b1.max = math.max(b1.max, b2.max)
        b1.oneFound = b1.oneFound | b2.oneFound
        b1
    }

    def finish(reduction: Buffer): Int = if (reduction.oneFound) 1 else reduction.max

    def bufferEncoder: Encoder[Buffer] = Encoders.product

    def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

spark.udf.register("maxIfNo1Found", functions.udaf(MaxIfNo1Found))

df.agg(expr("maxIfNo1Found(a)").as("maxA"), max("b").as("maxB")).show
+----+----+
|maxA|maxB|
+----+----+
|   1|   6|
+----+----+

就是这样!

如何更有意义地去做

定义一个自定义聚合似乎是非常不必要的。就这样做:

df.agg(
    max("a").as("maxA"),
    expr("bool_or(a = 1)").as("oneFound"),
    max("b").as("maxB")
  )
  .select(
    when(col("oneFound"), lit(1))
      .otherwise(col("maxA")).as("maxA"),  
    col("maxB")
  )
  .show
+----+----+
|maxA|maxB|
+----+----+
|   1|   6|
+----+----+

bool_or是一个聚合函数,通过给定列的所有元素计算布尔逻辑或。在这种情况下,如果a列中只有一个元素等于1,则oneFound将是True。要使其为False,您需要列a的所有值都不同于1(只有一个true会使孔表达式的计算结果为true)。最后,根据oneFound的值选择1maxA

aurhwmvo

aurhwmvo2#

df.select(
    when(max($"a" === lit(1)), lit(1)).otherwise(max("a")).as("maxA"),
    max($"b").as("maxB")
  )

相关问题