如何对数组列的元素进行切片和求和?

moiiocjp  于 2021-05-27  发布在  Spark
关注(0)|答案(6)|浏览(575)

我想 sum (也可以执行其他聚合函数)使用sparksql对数组列执行。
我有一张table

+-------+-------+---------------------------------+
|dept_id|dept_nm|                      emp_details|
+-------+-------+---------------------------------+
|     10|Finance|        [100, 200, 300, 400, 500]|
|     20|     IT|                [10, 20, 50, 100]|
+-------+-------+---------------------------------+

我想把这个的值加起来 emp_details 列。
预期查询:

sqlContext.sql("select sum(emp_details) from mytable").show

预期结果

1500
180

我也应该能够对范围元素求和,比如:

sqlContext.sql("select sum(slice(emp_details,0,3)) from mytable").show

结果

600
80

当按预期对数组类型执行sum时,它显示sum预期参数是数值类型而不是数组类型。
我认为我们需要为此创建自定义项。但是怎么做呢?
我将面临任何与自定义项性能打击?除了udf之外还有其他解决方案吗?

e7arh2l6

e7arh2l61#

rdd方法丢失了,所以让我添加它。

val df = Seq((10, "Finance", Array(100,200,300,400,500)),(20, "IT", Array(10,20,50,100))).toDF("dept_id", "dept_nm","emp_details")

import scala.collection.mutable._

val rdd1 = df.rdd.map( x=> {val p = x.getAs[mutable.WrappedArray[Int]]("emp_details").toArray; Row.merge(x,Row(p.sum,p.slice(0,2).sum)) })

spark.createDataFrame(rdd1,df.schema.add(StructField("sumArray",IntegerType)).add(StructField("sliceArray",IntegerType))).show(false)

输出:

+-------+-------+-------------------------+--------+----------+
|dept_id|dept_nm|emp_details              |sumArray|sliceArray|
+-------+-------+-------------------------+--------+----------+
|10     |Finance|[100, 200, 300, 400, 500]|1500    |300       |
|20     |IT     |[10, 20, 50, 100]        |180     |30        |
+-------+-------+-------------------------+--------+----------+
xpcnnkqh

xpcnnkqh2#

因为spark 2.4你可以用 slice 功能:

import org.apache.spark.sql.functions.slice

val df = Seq(
  (10, "Finance", Seq(100, 200, 300, 400, 500)),
  (20, "IT", Seq(10, 20, 50, 100))
).toDF("dept_id", "dept_nm", "emp_details")

val dfSliced = df.withColumn(
   "emp_details_sliced",
   slice($"emp_details", 1, 3)
)

dfSliced.show(false)
+-------+-------+-------------------------+------------------+
|dept_id|dept_nm|emp_details              |emp_details_sliced|
+-------+-------+-------------------------+------------------+
|10     |Finance|[100, 200, 300, 400, 500]|[100, 200, 300]   |
|20     |IT     |[10, 20, 50, 100]        |[10, 20, 50]      |
+-------+-------+-------------------------+------------------+

与数组求和 aggregate :

dfSliced.selectExpr(
  "*", 
  "aggregate(emp_details, 0, (x, y) -> x + y) as details_sum",  
  "aggregate(emp_details_sliced, 0, (x, y) -> x + y) as details_sliced_sum"
).show
+-------+-------+--------------------+------------------+-----------+------------------+
|dept_id|dept_nm|         emp_details|emp_details_sliced|details_sum|details_sliced_sum|
+-------+-------+--------------------+------------------+-----------+------------------+
|     10|Finance|[100, 200, 300, 4...|   [100, 200, 300]|       1500|               600|
|     20|     IT|   [10, 20, 50, 100]|      [10, 20, 50]|        180|                80|
+-------+-------+--------------------+------------------+-----------+------------------+
bwleehnv

bwleehnv3#

这里有一个替代mtoto的答案而不使用 groupBy (我真的不知道哪一个最快:udf,mtoto解决方案还是我的,欢迎评论)
使用 UDF ,一般来说。有一个答案,你可能想读,这是一个很好的资源阅读自定义项。
现在对于您的问题,您可以避免使用自定义项。我会用一个 Column 用scala逻辑生成的表达式。
数据:

val df = Seq((10, "Finance", Array(100,200,300,400,500)),
                  (20, "IT", Array(10,  20, 50,100)))
          .toDF("dept_id", "dept_nm","emp_details")

你需要一些诡计才能穿越 ArrayType ,您可以使用解决方案来发现各种问题(有关详细信息,请参阅底部的编辑) slice 零件)。这是我的建议,但你可能会发现更好。首先取最大长度

val maxLength = df.select(size('emp_details).as("l")).groupBy().max("l").first.getInt(0)

然后使用它,当你有一个较短的数组时进行测试

val sumArray = (1 until maxLength)
      .map(i => when(size('emp_details) > i,'emp_details(i)).otherwise(lit(0)))
      .reduce(_ + _)
      .as("sumArray")

val res = df
  .select('dept_id,'dept_nm,'emp_details,sumArray)

结果:

+-------+-------+--------------------+--------+
|dept_id|dept_nm|         emp_details|sumArray|
+-------+-------+--------------------+--------+
|     10|Finance|[100, 200, 300, 4...|    1500|
|     20|     IT|   [10, 20, 50, 100]|     180|
+-------+-------+--------------------+--------+

我建议你看看 sumArray 去了解它在做什么。
编辑:当然我只看了一半的问题了。。。但是,如果您想更改要求和的项,您可以看到使用此解决方案会变得很明显(即,您不需要切片函数),只需更改 (0 until maxLength) 使用所需的索引范围:

def sumArray(from: Int, max: Int) = (from until max)
      .map(i => when(size('emp_details) > i,'emp_details(i)).otherwise(lit(0)))
      .reduce(_ + _)
      .as("sumArray")
dluptydi

dluptydi4#

Spark2.4.0

从spark2.4开始,sparksql支持高阶函数来操作复杂的数据结构,包括数组。
“现代”解决办法如下:

scala> input.show(false)
+-------+-------+-------------------------+
|dept_id|dept_nm|emp_details              |
+-------+-------+-------------------------+
|10     |Finance|[100, 200, 300, 400, 500]|
|20     |IT     |[10, 20, 50, 100]        |
+-------+-------+-------------------------+

input.createOrReplaceTempView("mytable")

val sqlText = "select dept_id, dept_nm, aggregate(emp_details, 0, (acc, value) -> acc + value) as sum from mytable"
scala> sql(sqlText).show
+-------+-------+----+
|dept_id|dept_nm| sum|
+-------+-------+----+
|     10|Finance|1500|
|     20|     IT| 180|
+-------+-------+----+

您可以在以下文章和视频中找到有关高阶函数的详细资料:
在ApacheSpark2.4中为复杂数据类型引入新的内置函数和高阶函数
在databricks上使用sql中的高阶函数处理嵌套数据
用herman van hovell(databricks)介绍spark-sql中的高阶函数

spark 2.3.2及更早版本

免责声明我不推荐这种方法(尽管它获得了最多的赞成票),因为sparksql执行的是反序列化 Dataset.map . 该查询强制spark反序列化数据并将其加载到jvm(从由spark在jvm外部管理的内存区域)。这将不可避免地导致更频繁的gcs,从而使性能更差。
一种解决办法是 Dataset sparksql和scala的结合可以展示其强大功能的解决方案。

scala> val inventory = Seq(
     |   (10, "Finance", Seq(100, 200, 300, 400, 500)),
     |   (20, "IT", Seq(10, 20, 50, 100))).toDF("dept_id", "dept_nm", "emp_details")
inventory: org.apache.spark.sql.DataFrame = [dept_id: int, dept_nm: string ... 1 more field]

// I'm too lazy today for a case class
scala> inventory.as[(Long, String, Seq[Int])].
  map { case (deptId, deptName, details) => (deptId, deptName, details.sum) }.
  toDF("dept_id", "dept_nm", "sum").
  show
+-------+-------+----+
|dept_id|dept_nm| sum|
+-------+-------+----+
|     10|Finance|1500|
|     20|     IT| 180|
+-------+-------+----+

我把切片部分留作练习,因为它同样简单。

mm5n2pyu

mm5n2pyu5#

基于zero323令人敬畏的答案;如果您有一个长整数数组,即bigint,您需要将初始值从0更改为bigint(0),如第一段所述,以便

dfSliced.selectExpr(
  "*", 
  "aggregate(emp_details, BIGINT(0), (x, y) -> x + y) as details_sum",  
  "aggregate(emp_details_sliced, BIGINT(0), (x, y) -> x + y) as details_sliced_sum"
).show
nwnhqdif

nwnhqdif6#

可能使用的方法 explode() 在你的 Array 列,从而按唯一键聚合输出。例如:

import sqlContext.implicits._
import org.apache.spark.sql.functions._

(mytable
  .withColumn("emp_sum",
    explode($"emp_details"))
  .groupBy("dept_nm")
  .agg(sum("emp_sum")).show)
+-------+------------+
|dept_nm|sum(emp_sum)|
+-------+------------+
|Finance|        1500|
|     IT|         180|
+-------+------------+

要仅选择数组中的特定值,我们可以使用链接问题的答案,并稍加修改即可应用:

val slice = udf((array : Seq[Int], from : Int, to : Int) => array.slice(from,to))

(mytable
  .withColumn("slice", 
    slice($"emp_details", 
      lit(0), 
      lit(3)))
  .withColumn("emp_sum",
    explode($"slice"))
  .groupBy("dept_nm")
  .agg(sum("emp_sum")).show)
+-------+------------+
|dept_nm|sum(emp_sum)|
+-------+------------+
|Finance|         600|
|     IT|          80|
+-------+------------+

数据:

val data = Seq((10, "Finance", Array(100,200,300,400,500)),
               (20, "IT", Array(10,20,50,100)))
val mytable = sc.parallelize(data).toDF("dept_id", "dept_nm","emp_details")

相关问题