Apache Spark 创建不可为空的数组列

qyzbxkaa  于 2023-10-23  发布在  Apache
关注(0)|答案(1)|浏览(128)

我在scala中使用spark(2.4)。我有一个数组,我试图用default值(空数组)替换空值(数组列)。

val emptyStringArray = udf(() => Array.empty[String],
DataTypes.createArrayType(DataTypes.StringType, false))

def ensureNonNullCol: DataFrame => DataFrame = inputDf => {

    inputDf.select(inputDf.schema.fields.map { f: StructField =>
      f.dataType match {
        case array: ArrayType => new Column(
          
          AssertNotNull(when(col(f.name).isNull,
            array.elementType match {
              case DataTypes.StringType => emptyStringArray()
            }).otherwise(col(f.name)).expr)

        ).as(f.name)
      }
    }: _*)

}

最后,我得到:

|-- StrAarrayColumn: array (nullable = false)
 |    |-- element: string (containsNull = true)

我怎么能有:

|-- StrAarrayColumn: array (nullable = false)
 |    |-- element: string (containsNull = false)

yshpjwxd

yshpjwxd1#

问题是,你的第一个框架有一个结构体,它包含一个可能包含null的字符串数组。现在你的ensureNonNullCol函数接收到一个带有一些结构体的输入框架,你只需要选择一些值,不改变你的框架的结构体,只需要返回它。在我开始解决方案之前,你的代码有3个要点。

  1. 1.只匹配许多情况中的一个可能情况是非常危险的,不鼓励的,并会导致匹配错误(在您的代码中,当您匹配ArrayType和StringType时,请注意)
    1. udf与空的输入参数是不鼓励的,并导致警告,在我的Spark版本,它导致运行时异常。
  2. 1.你可以在适当的位置返回一个空的字符串数组,而不是调用udf。
    无论如何,解决方案是在选择函数中所需的字段后也更新结构类型:
def ensureNonNullCol: DataFrame => DataFrame = inputDf => {
    val newStruct = StructType(inputDf.schema.map { field =>
      val newDataType = field.dataType match {
        case arr: ArrayType if arr.elementType == StringType => arr.copy(containsNull = false)
        case other => other
      }
      field.copy(dataType = newDataType)
    })
    val newDF = inputDf.select(inputDf.schema.fields.map { f: StructField =>
      f.dataType match {
        case array: ArrayType => new Column(

          AssertNotNull(when(col(f.name).isNull,
            array.elementType match {
              case DataTypes.StringType => Array.empty[String]
              case _ => col(f.name)
            }).otherwise(col(f.name)).expr)
        ).as(f.name)

        case _ => col(f.name)
      }
    }: _*)

    spark.createDataFrame(newDF.rdd, newStruct)

  }

相关问题