Merging Maps in Scala

0yycz8jy · posted 2023-10-18 · Scala

I have a dataframe with columns col1, col2, col3. col1 and col2 are strings. col3 is a Map[String,String] with the schema below:

|-- col3: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

I have grouped by col1, col2 and aggregated with collect_list to get an array of maps, stored in col4:

df.groupBy($"col1", $"col2").agg(collect_list($"col3").as("col4"))

 |-- col4: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)

However, I want col4 to be a single Map with all of the maps combined. Currently I have:

[[a->a1,b->b1],[c->c1]]

Expected output:

[a->a1,b->b1,c->c1]

Would using a UDF be the ideal approach here?
Any help is appreciated. Thanks.
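
For reference, a minimal sketch that reproduces the schema above (the sample values are made up; it assumes spark is the active SparkSession, e.g. in spark-shell):

import org.apache.spark.sql.functions.collect_list
import spark.implicits._  // toDF and $-notation

// col1, col2: strings; col3: map<string,string>, as in the schema above
val df = Seq(
  ("x", "y", Map("a" -> "a1", "b" -> "b1")),
  ("x", "y", Map("c" -> "c1"))
).toDF("col1", "col2", "col3")

// grouping and collect_list produce col4: array<map<string,string>>
val grouped = df.groupBy($"col1", $"col2")
  .agg(collect_list($"col3").as("col4"))

grouped.printSchema()
grouped.show(false)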


yrwegjxp1#

You can use aggregate with map_concat:

import org.apache.spark.sql.functions.{expr, collect_list}
import spark.implicits._  // for toDF and $-notation (spark is the active SparkSession)

val df = Seq(
  (1, Map("k1" -> "v1", "k2" -> "v3")),
  (1, Map("k3" -> "v3")),
  (2, Map("k4" -> "v4")),
  (2, Map("k6" -> "v6", "k5" -> "v5"))
).toDF("id", "data")

val mergeExpr = expr("aggregate(data, map(), (acc, i) -> map_concat(acc, i))")

df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeExpr.as("merged_data"))
  .show(false)

// +---+------------------------------+
// |id |merged_data                   |
// +---+------------------------------+
// |1  |[k1 -> v1, k2 -> v3, k3 -> v3]|
// |2  |[k4 -> v4, k6 -> v6, k5 -> v5]|
// +---+------------------------------+

With map_concat we concatenate all the map items of the data column via the aggregate built-in function, which lets us apply the aggregation over the items of the collected list.
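
Conceptually, this higher-order aggregate behaves like a left fold over the collected maps; a plain-Scala analogue, just to illustrate the mechanics:

// plain Scala analogue of aggregate(data, map(), (acc, i) -> map_concat(acc, i))
val maps = Seq(Map("k1" -> "v1", "k2" -> "v3"), Map("k3" -> "v3"))
val merged = maps.foldLeft(Map.empty[String, String])(_ ++ _)
// merged: Map(k1 -> v1, k2 -> v3, k3 -> v3)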

Note: in the current implementation of map_concat on Spark 2.4.5, identical keys are allowed to coexist. This is most likely a bug, since according to the official documentation it is not the expected behaviour. Please be aware of that.

If you want to avoid such a case you can also go for a UDF:

import org.apache.spark.sql.functions.{collect_list, udf}

val mergeMapUDF = udf((data: Seq[Map[String, String]]) => data.reduce(_ ++ _))

df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeMapUDF($"data").as("merged_data"))
  .show(false)
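
As a side note, Scala's ++ used here resolves duplicate keys with last-wins semantics, and reduce throws on an empty sequence (collect_list yields an empty array when every map in a group is null); a slightly safer variant of this UDF, sketched with foldLeft:

import org.apache.spark.sql.functions.{collect_list, udf}

// foldLeft starts from an empty map, so an empty input array yields an empty map
// instead of the UnsupportedOperationException that reduce would throw;
// ++ keeps the value of the right-hand map when a key repeats (last wins)
val mergeMapUDF = udf((data: Seq[Map[String, String]]) =>
  data.foldLeft(Map.empty[String, String])(_ ++ _))

// $-notation assumes spark.implicits._ is in scope, as above
df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeMapUDF($"data").as("merged_data"))
  .show(false)
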
Update (2022-08-27):

1. In Spark 3.3.0 the aggregate/map_concat code above no longer works and throws the following exception:

AnalysisException: cannot resolve 'aggregate(`data`, map(), lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))' due to data type mismatch: argument 3 requires map<null,null> type, however, 'lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable())' is of map<string,string> type.;
Project [id#110, aggregate(data#119, map(), lambdafunction(map_concat(cast(lambda acc#122 as map<string,string>), lambda i#123), lambda acc#122, lambda i#123, false), lambdafunction(lambda id#124, lambda id#124, false)) AS aggregate(data, map(), lambdafunction(map_concat(namedlambdavariable(), namedlambdavariable()), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))#125]
+- Aggregate [id#110], [id#110, collect_list(data#111, 0, 0) AS data#119]
   +- Project [_1#105 AS id#110, _2#106 AS data#111]
      +- LocalRelation [_1#105, _2#106]

It seems that map() is initialized as a map<null,null>, whereas a map<string,string> is expected.
To work around this, simply cast map() explicitly to map<string,string> using cast(map() as map<string,string>).
Here is the updated code:

val mergeExpr = expr("aggregate(data, cast(map() as map<string,string>), (acc, i) -> map_concat(acc, i))")

df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeExpr)
  .show(false)

2. Regarding the duplicate-key bug, it seems to be fixed in recent versions. If you try to add identical keys, an exception is thrown:

Caused by: RuntimeException: Duplicate map key k5 was found, please check the input data. If you want to remove the duplicated keys, you can set spark.sql.mapKeyDedupPolicy to LAST_WIN so that the key inserted at last takes precedence.
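
If duplicate keys are expected and last-wins semantics are acceptable, the error message itself points to the spark.sql.mapKeyDedupPolicy setting; a minimal sketch reusing mergeExpr from above (spark is the active SparkSession):

// LAST_WIN keeps the value inserted last instead of raising the duplicate-key error
spark.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

df.groupBy("id").agg(collect_list("data").as("data"))
  .select($"id", mergeExpr.as("merged_data"))
  .show(false)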

5lhxktic2#

You can achieve it without a UDF. Let's create your dataframe:

import spark.implicits._  // for toDF on a local Seq (spark is the active SparkSession)

val df = Seq(Seq(Map("a" -> "a1", "b" -> "b1"), Map("c" -> "c1", "d" -> "d1"))).toDF()
df.show(false)
df.printSchema()

Output:

+----------------------------------------+
|value                                   |
+----------------------------------------+
|[[a -> a1, b -> b1], [c -> c1, d -> d1]]|
+----------------------------------------+

root
 |-- value: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)

If your array contains exactly 2 elements, just use map_concat:

df.select(map_concat('value.getItem(0), 'value.getItem(1))).show(false)

Or this (I don't know how to loop dynamically from 0 to the size of the 'value array-type column, which might otherwise be the shortest solution; a size-agnostic option is sketched right after this snippet):

df.select(map_concat((for {i <- 0 to 1} yield 'value.getItem(i)): _*)).show(false)
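
For an array of arbitrary length, the aggregate higher-order expression from the first answer can also be applied to this value column; a minimal sketch, written with the explicit cast that newer Spark versions require:

import org.apache.spark.sql.functions.expr

// size-agnostic merge of the whole array of maps
df.select(
  expr("aggregate(value, cast(map() as map<string,string>), (acc, i) -> map_concat(acc, i))")
    .as("merged")
).show(false)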

Otherwise, if your array contains more maps and its size is unknown, you can try the following:

// merge the array of maps using the typed Dataset API
// (the Encoder for Map[String, String] comes from spark.implicits._)
val df2 = df.map { s =>
  val list = s.getList[Map[String, String]](0)
  var map = Map[String, String]()
  for (i <- 0 until list.size()) {
    map = map ++ list.get(i)
  }
  map
}

df2.show(false)
df2.printSchema()

Output:

+------------------------------------+
|value                               |
+------------------------------------+
|[a -> a1, b -> b1, c -> c1, d -> d1]|
+------------------------------------+

root
 |-- value: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

bttbmeg03#

If the number of records is small, then you can explode them, collect them back as struct(), and then use map_from_entries:

val df = Seq(Seq(Map("a" -> "a1", "b" -> "b1"), Map("c" -> "c1", "d" -> "d1"))).toDF()
df.show(false)
df.printSchema()

+----------------------------------------+
|value                                   |
+----------------------------------------+
|[{a -> a1, b -> b1}, {c -> c1, d -> d1}]|
+----------------------------------------+

root
 |-- value: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)

df.createOrReplaceTempView("items")

val df2 = spark.sql("""

with t1 (select value from items),
     t2 (select value, explode(value) m1 from t1 ),
     t3 (select value, explode(m1) (k,v) from t2 ),
     t4 (select value, struct(k,v) r1 from t3 ),
     t5 (select collect_list(r1) r2 from t4 )
     select map_from_entries(r2) merged_data from t5
    """)
df2.show(false)
df2.printSchema

+------------------------------------+
|merged_data                         |
+------------------------------------+
|{a -> a1, b -> b1, c -> c1, d -> d1}|
+------------------------------------+

root
 |-- merged_data: map (nullable = false)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

Note that when we use "value" in the group-by, Spark throws org.apache.spark.sql.AnalysisException: expression t4.value cannot be used as a grouping expression because its data type array<map<string,string>> is not an orderable data type.
Let's take abiratsis' sample data. Here we have to use the id column in the group-by, otherwise all the map elements would be merged together.

val df = Seq(
  (1, Map("k1" -> "v1", "k2" -> "v3")),
  (1, Map("k3" -> "v3")),
  (2, Map("k4" -> "v4")),
  (2, Map("k6" -> "v6", "k5" -> "v5"))
).toDF("id", "data")
df.show(false)
df.printSchema()

+---+--------------------+
|id |data                |
+---+--------------------+
|1  |{k1 -> v1, k2 -> v3}|
|1  |{k3 -> v3}          |
|2  |{k4 -> v4}          |
|2  |{k6 -> v6, k5 -> v5}|
+---+--------------------+

root
 |-- id: integer (nullable = false)
 |-- data: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

df.createOrReplaceTempView("items")

val df2 = spark.sql("""

with t1 (select id, data from items),
     t2 (select id, explode(data) (k,v) from t1 ),
     t3 (select id, struct(k,v) r1 from t2 ),
     t4 (select id, collect_list(r1) r2 from t3 group by id )
     select id, map_from_entries(r2) merged_data from t4
    """)
df2.show(false)
df2.printSchema

+---+------------------------------+
|id |merged_data                   |
+---+------------------------------+
|1  |{k1 -> v1, k2 -> v3, k3 -> v3}|
|2  |{k4 -> v4, k6 -> v6, k5 -> v5}|
+---+------------------------------+

root
 |-- id: integer (nullable = false)
 |-- merged_data: map (nullable = false)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)
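
The same explode / struct / collect_list / map_from_entries pipeline can also be written with the DataFrame API; a minimal sketch equivalent to the SQL above:

import org.apache.spark.sql.functions.{collect_list, explode, map_from_entries, struct}

// explode on a map column yields two columns named "key" and "value"
df.select($"id", explode($"data"))
  .groupBy("id")
  .agg(map_from_entries(collect_list(struct($"key", $"value"))).as("merged_data"))
  .show(false)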
