spark中的where子句,用于struct数组中的元素?

sirbozc5  于 2021-06-26  发布在  Hive
关注(0)|答案(1)|浏览(322)

我有两个案例:

case class Doc(posts: Seq[Post], test: String)
case class Post(postId: Int, createdTime: Long)

我创建了一个示例df:

val df = spark.sparkContext.parallelize(Seq(
Doc(Seq(
  Post(1, 1),
  Post(2, 3),
  Post(3, 8),
  Post(4, 15)
), null),
Doc(Seq(
  Post(5, 6),
  Post(6, 9),
  Post(7, 12),
  Post(8, 20)
), "hello") )).toDF()

所以我想要的是,返回带有帖子的在线文档,其中createtime在x和y之间。例如,对于x=2 et y=9,我希望这个结果与原始df的模式相同:

+--------------+
|         posts|
+--------------+
|[[2,3], [3,8]]|
|[[5,6], [6,9]]|
+--------------+

所以我尝试了很多组合的地方,但我没有工作。我试着用 map(_.filter(...)) ,但问题是我不想做 toDF().as[Doc] 有什么帮助吗?谢谢您

balp4ylt

balp4ylt1#

有几种方法可以做到这一点:
使用自定义项
使用爆炸和收集
使用databricks工具
自定义项
udf是一条捷径。您基本上创建了一个自定义函数来完成这项工作。与转换为数据集不同,它不会构造整个doc类,而是只处理相关数据:

def f(posts: Seq[Row]): Seq[Post] = {
  posts.map(r => Post(r.getAs[Int](0), r.getAs[Long](1))).filter(p => p.createdTime > 3 && p.createdTime < 9))
}
val u = udf(f _)
val filtered = df.withColumn("posts", u($"posts"))

使用分解和收集列表

df.withColumn("posts", explode($"posts")).filter($"posts.createdTime" > 3 && $"posts.createdTime" < 9).groupBy("test").agg(collect_list("posts").as("posts"))

这可能比前一个效率要低,但它是一个单行程序(在将来的某个时刻它可能会得到优化)。
使用databricks工具
如果你在databricks云上工作,你可以使用高阶函数。更多信息请参见此处。因为这不是一个一般的Spark选择,我不会去了。
希望将来他们能将它集成到标准的spark中(我在这个主题上找到了jira,但目前还不支持它)。

相关问题