javaudf处理数组列

vsdwdz23  于 2021-05-27  发布在  Spark
关注(0)|答案(1)|浏览(467)

我正在编写一个javaudf来处理一个数组类型的列。
其目的是处理字符串数组以选择长度最短的字符串

sqlContext.udf().register("NAME_SELECTOR", (UDF1<List<String>, String>) brandNames -> {
                          brandNames.sort(Comparator.comparing(String::length));
                          return brandNames.get(0);},DataTypes.StringType);

错误与udf函数的输入类型有关。我知道在scala我需要 Seq[String] 作为输入类型,java如何?
以下是错误消息: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.util.List

cigdeys3

cigdeys31#

试试这个-
使用 scala.collection.mutable.WrappedArray 并使用 JavaConverters 然后使用比较器对其进行排序并获取第一个最短的字符串-

Dataset<Row> df = spark.sql("select array('abc', 'ab', 'a') arr");
        df.printSchema();
        df.show(false);
        /**
         * root
         *  |-- arr: array (nullable = false)
         *  |    |-- element: string (containsNull = false)
         *
         * +------------+
         * |arr         |
         * +------------+
         * |[abc, ab, a]|
         * +------------+
         */

        // scala.collection.mutable.WrappedArray
        UserDefinedFunction shortestStringUdf = udf((WrappedArray<String> arr)  -> {
                    List<String> strings = new ArrayList<>(JavaConverters
                            .asJavaCollectionConverter(arr)
                            .asJavaCollection());
                    strings.sort(Comparator.comparing(String::length));
                    return strings.get(0);
                }
                , DataTypes.StringType);
        spark.udf().register("shortestString", shortestStringUdf);

        df.withColumn("a", expr("shortestString(arr)"))
        .show(false);
        /**
         * +------------+---+
         * |arr         |a  |
         * +------------+---+
         * |[abc, ab, a]|a  |
         * +------------+---+
         */

如果你在 spark>=2.4 使用高阶函数来获得相同的结果 without udf 如下所示-

// spark>=2.4
        df.withColumn("arr_length", expr("TRANSFORM(arr, x -> length(x))"))
                .withColumn("a", expr("array_sort(arrays_zip(arr_length, arr))[0].arr"))
                .show(false);
        /**
         * +------------+----------+---+
         * |arr         |arr_length|a  |
         * +------------+----------+---+
         * |[abc, ab, a]|[3, 2, 1] |a  |
         * +------------+----------+---+
         */

相关问题