使用spark 1.5和scala 2.10.6
我试图通过一个字符串数组字段“tags”来过滤Dataframe。正在查找标记为“private”的所有行。
val report = df.select("*")
.where(df("tags").contains("private"))
得到:
线程“main”org.apache.spark.sql.analysisexception中出现异常:由于数据类型不匹配,无法解析“contains(tags,private)”:参数1需要字符串类型,但是,“tags”是数组类型。;
过滤法更合适吗?
更新时间:
数据来自cassandra adapter,但最简单的示例显示了我正在尝试执行的操作,也得到了上述错误:
def testData (sc: SparkContext): DataFrame = {
val stringRDD = sc.parallelize(Seq("""
{ "name": "ed",
"tags": ["red", "private"]
}""",
"""{ "name": "fred",
"tags": ["public", "blue"]
}""")
)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._
sqlContext.read.json(stringRDD)
}
def run(sc: SparkContext) {
val df1 = testData(sc)
df1.show()
val report = df1.select("*")
.where(df1("tags").contains("private"))
report.show()
}
更新:标签数组可以是任意长度,“private”标签可以位于任意位置
更新:一个有效的解决方案:udf
val filterPriv = udf {(tags: mutable.WrappedArray[String]) => tags.contains("private")}
val report = df1.filter(filterPriv(df1("tags")))
2条答案
按热度按时间ddarikpa1#
我想如果你用
where(array_contains(...))
会有用的。我的结果是:请注意,如果你写
where(array_contains(df("tags"), "private"))
,但如果你写where(df("tags").array_contains("private"))
(更直接地类似于你最初写的)它失败了array_contains is not a member of org.apache.spark.sql.Column
. 查看源代码Column
,我看到有些事情要处理contains
(构建一个Contains
例如)但不是array_contains
. 也许这是疏忽。vbkedwbf2#
您可以使用ordinal来引用json数组的
df("tags")(0)
. 这是一个工作样本