scala 在用户定义函数(UDF)中拟合LogisticRegression

euoag5mw  于 12个月前  发布在  Scala
关注(0)|答案(1)|浏览(163)

我在Spark Scala中实现了以下代码:

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.classification._

object Hello {
    def main(args: Array[String]) = {

          val getLabel1Probability = udf((param1: Double, labeledEntries: Seq[Array[Double]]) => {

            val trainingData = labeledEntries.map(entry => (org.apache.spark.ml.linalg.Vectors.dense(entry(0)), entry(1))).toList.toDF("features", "label")
            val regression = new LogisticRegression()
            val fittingModel = regression.fit(trainingData)

            val prediction = fittingModel.predictProbability(org.apache.spark.ml.linalg.Vectors.dense(param1))
            val probability = prediction.toArray(1)

            probability
          })

          val df = Seq((1.0, Seq(Array(1.0, 0), Array(2.0, 1))), (3.0, Seq(Array(1.0, 0), Array(2.0, 1)))).toDF("Param1", "LabeledEntries")

          val dfWithLabel1Probability = df.withColumn(
                "Label1Probability", getLabel1Probability(
                  $"Param1",
                  $"LabeledEntries"
                )
          )
          display(dfWithLabel1Probability)
    }
}

Hello.main(Array())

当在Databricks的笔记本多节点集群(DBR(Databricks)13.2,Spark 3.4.0和Scala 2.12.)上运行它时,dfWithLabel1Probability的显示会显示出来。
我有以下问题:

  • 我的理解是,在创建trainingData框架时,我应该得到一个NullPointerException,因为_sqlContext在udf中为null。如果是这样,为什么我没有得到它?是不是和从数据库的笔记本上运行有关行为是非确定性的吗?
  • 如果在udf中不允许创建数据框,我如何用给定数据框列中的数据拟合LogisticRegression?在这个真实的示例中,我处理的是数百万行的数据,所以我更倾向于避免使用Dataset's collect()来将所有这些行都放入驱动程序的内存中。有别的选择吗?

谢谢.

vuktfyat

vuktfyat1#

对于第一个问题,如果你运行:

val largedf = spark.range(100000).selectExpr("cast(id as double) Param1", "array(array(1.0, 0), array(2.0, 1)) LabeledEntries")

val largedfWithLabel1Probability = largedf.withColumn(
    "Label1Probability", getLabel1Probability(
      $"Param1",
      $"LabeledEntries"
    )
)

display(largedfWithLabel1Probability)

它将npe,范围为1,但使用:

(1 until 1000).map(a => (a.toDouble, Seq.. )).toDF..

它至少会开始处理。这是因为toDF使用LocalRelation来构建不发送到执行器的数据,而Range使用LeafNodes(执行器),因此出现异常。
这是第二个值得作为单独的顶级问题提出的问题。

相关问题