Spark: flatMap and CountVectorizer pipeline

Posted by de90aj5v on 2021-05-27 in Spark

I'm working with pipelines and trying to split a column's values before passing them to `CountVectorizer`.
For this I wrote a custom transformer.

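import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, explode, udf}
import org.apache.spark.sql.types.{DataTypes, StringType, StructType}

class FlatMapTransformer(override val uid: String)
  extends Transformer {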
  /**
   * Param for input column name.
   * @group param
   */
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final def getInputCol: String = $(inputCol)

  /**
   * Param for output column name.
   * @group param
   */
  final val outputCol = new Param[String](this, "outputCol", "The output column")
  final def getOutputCol: String = $(outputCol)

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("FlatMapTransformer"))

  private val flatMap: String => Seq[String] = { input: String =>
    input.split(",")
  }

  override def copy(extra: ParamMap): SplitString = defaultCopy(extra)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val flatMapUdf = udf(flatMap)
    dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
  }

  override def transformSchema(schema: StructType): StructType = {
    val dataType = schema($(inputCol)).dataType
    require(
      dataType.isInstanceOf[StringType],
      s"Input column must be of type StringType but got ${dataType}")
    val inputFields = schema.fields
    require(
      !inputFields.exists(_.name == $(outputCol)),
      s"Output column ${$(outputCol)} already exists.")

    DataTypes.createStructType(
      Array(
        DataTypes.createStructField($(outputCol), DataTypes.StringType, false)))
  }
}

The code seems legit, but the problem appears when I try to chain it with other stages. Here is my pipeline:

val train = reader.readTrainingData()

val cat_features = getFeaturesByType(taskConfig, "categorical")
val num_features = getFeaturesByType(taskConfig, "numeric")
val cat_ohe_features = getFeaturesByType(taskConfig, "categorical", Some("ohe"))
val cat_features_string_index = cat_features.
  filter { feature: String => !cat_ohe_features.contains(feature) }

val catIndexer = cat_features_string_index.map {
  feature =>
    new StringIndexer()
      .setInputCol(feature)
      .setOutputCol(feature + "_index")
      .setHandleInvalid("keep")
}

val flatMapper = cat_ohe_features.map {
  feature =>
    new FlatMapTransformer()
      .setInputCol(feature)
      .setOutputCol(feature + "_transformed")
}

val countVectorizer = cat_ohe_features.map {
  feature =>
    new CountVectorizer()
      .setInputCol(feature + "_transformed")
      .setOutputCol(feature + "_vectorized")
      .setVocabSize(10)
}

// val countVectorizer = cat_ohe_features.map {
//   feature =>
//
//     val flatMapper = new FlatMapTransformer()
//       .setInputCol(feature)
//       .setOutputCol(feature + "_transformed")
// 
//     new CountVectorizer()
//       .setInputCol(flatMapper.getOutputCol)
//       .setOutputCol(feature + "_vectorized")
//       .setVocabSize(10)
// }

val cat_features_index = cat_features_string_index.map {
  (feature: String) => feature + "_index"
}

val count_vectorized_index = cat_ohe_features.map {
  (feature: String) => feature + "_vectorized"
}

val catFeatureAssembler = new VectorAssembler()
  .setInputCols(cat_features_index)
  .setOutputCol("cat_features")

val oheFeatureAssembler = new VectorAssembler()
  .setInputCols(count_vectorized_index)
  .setOutputCol("cat_ohe_features")

val numFeatureAssembler = new VectorAssembler()
  .setInputCols(num_features)
  .setOutputCol("num_features")

val featureAssembler = new VectorAssembler()
  .setInputCols(Array("cat_features", "num_features", "cat_ohe_features_vectorized"))
  .setOutputCol("features")

val pipelineStages = catIndexer ++ flatMapper ++ countVectorizer ++
  Array(
    catFeatureAssembler,
    oheFeatureAssembler,
    numFeatureAssembler,
    featureAssembler)

val pipeline = new Pipeline().setStages(pipelineStages)
pipeline.fit(dataset = train)

When I run this code, I get an error: `java.lang.IllegalArgumentException: Field "my_ohe_field_transformed" does not exist.`
```
[info] java.lang.IllegalArgumentException: Field "from_expdelv_areas_transformed" does not exist.
[info] at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info] at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info] at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
[info] at scala.collection.AbstractMap.getOrElse(Map.scala:59)
[info] at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
[info] at org.apache.spark.ml.util.SchemaUtils$.checkColumnTypes(SchemaUtils.scala:56)
[info] at org.apache.spark.ml.feature.CountVectorizerParams$class.validateAndTransformSchema(CountVectorizer.scala:75)
[info] at org.apache.spark.ml.feature.CountVectorizer.validateAndTransformSchema(CountVectorizer.scala:123)
[info] at org.apache.spark.ml.feature.CountVectorizer.transformSchema(CountVectorizer.scala:188)
```

When I uncomment `stringSplitter` and `countVectorizer`, my transformer fails with `java.lang.IllegalArgumentException: Field "my_ohe_field" does not exist.` at `val dataType = schema($(inputCol)).dataType`. The result of calling `pipeline.getStages`:

strIdx_3c2630a738f0
strIdx_0d76d55d4200
FlatMapTransformer_fd8595c2969c
FlatMapTransformer_2e9a7af0b0fa
cntVec_c2ef31f00181
cntVec_68a78eca06c9
vecAssembler_a81dd9f43d56
vecAssembler_b647d348f0a0
vecAssembler_b5065a22d5c8
vecAssembler_d9176b8bb593

I may be on the wrong track. Any input would be appreciated.
Answer #1 (btxsgosb)

Your `FlatMapTransformer#transform` is incorrect: when you select only the `outputCol`, all the other columns are dropped. Please change the method to:

override def transform(dataset: Dataset[_]): DataFrame = {
  val flatMapUdf = udf(flatMap)
  dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}
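The difference between `select` and `withColumn` can be checked on a toy DataFrame. This is a minimal sketch, assuming `spark-sql` is on the classpath and a local `SparkSession` is acceptable; the object name `SelectVsWithColumn` and the data are hypothetical, and Spark's built-in `split` stands in for the custom UDF only to keep the example short.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode, split}

// Hypothetical demo object: compares the columns kept by select() vs withColumn().
object SelectVsWithColumn {
  def run(): (Seq[String], Seq[String]) = {
    val spark = SparkSession.builder().master("local[1]").appName("select-vs-withColumn").getOrCreate()
    import spark.implicits._

    // Toy data: an id plus a comma-separated string column.
    val df = Seq((1, "a,b"), (2, "c")).toDF("id", "tags")
    val exploded = explode(split(col("tags"), ","))

    // select() returns only the listed expressions, so "id" and "tags" disappear.
    val selectedCols = df.select(exploded.as("tags_transformed")).columns.toSeq
    // withColumn() appends the new column and keeps every existing one.
    val withColumnCols = df.withColumn("tags_transformed", exploded).columns.toSeq

    (selectedCols, withColumnCols)
  }

  def main(args: Array[String]): Unit = {
    val (selectedCols, withColumnCols) = run()
    println(s"select:     ${selectedCols.mkString(", ")}")
    println(s"withColumn: ${withColumnCols.mkString(", ")}")
  }
}
```

Because later pipeline stages look up their input columns in the schema produced by earlier stages, a stage that silently drops columns makes every downstream lookup fail.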

Also, change your `transformSchema` to check that the input column exists before checking its data type:

override def transformSchema(schema: StructType): StructType = {
  require(schema.names.contains($(inputCol)), s"Input column ${$(inputCol)} is not present in the input DataFrame")
  // ... rest as it is
}
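The check order can be exercised on a bare `StructType` without starting Spark. A small sketch, assuming only `spark-sql` types on the classpath; `InputColumnCheck` and `requireStringColumn` are hypothetical names, not part of Spark's API.

```scala
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

// Hypothetical helper illustrating the check order; not part of Spark's API.
object InputColumnCheck {
  def requireStringColumn(schema: StructType, inputCol: String): DataType = {
    // Check existence first so a missing column produces a readable message
    // instead of the lower-level error thrown by schema(name).
    require(schema.fieldNames.contains(inputCol),
      s"Input column '$inputCol' is not present. Available columns: ${schema.fieldNames.mkString(", ")}")
    val dataType = schema(inputCol).dataType
    require(dataType.isInstanceOf[StringType],
      s"Input column must be of type StringType but got $dataType")
    dataType
  }

  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("tags", StringType)))
    println(requireStringColumn(schema, "tags"))
    // Prints the "requirement failed: ..." message listing the available columns.
    scala.util.Try(requireStringColumn(schema, "tags_transformed")).failed.foreach(e => println(e.getMessage))
  }
}
```

`require` throws `IllegalArgumentException`, the same exception type as the original error, but the message now names both the missing column and the columns that are actually present, which makes a dropped column easy to spot.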

Update 1 (based on the comments)

Please also fix the `copy` method (although it is not the cause of the exception you are seeing):

override def copy(extra: ParamMap): FlatMapTransformer = defaultCopy(extra)

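Note that `CountVectorizer` expects an input column of type `ArrayType(StringType, true/false)`. Since the output column of `FlatMapTransformer` becomes the input of `CountVectorizer`, you need to make sure the output of `FlatMapTransformer` is `ArrayType(StringType, true/false)`. I don't think that is the case; your current code is: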

override def transform(dataset: Dataset[_]): DataFrame = {
  val flatMapUdf = udf(flatMap)
  dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}

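The `explode` function turns each `array<string>` value into one row per element, each of type `string`, so the transformer's output column becomes `StringType`. You probably want to change this code to: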

override def transform(dataset: Dataset[_]): DataFrame = {
  val flatMapUdf = udf(flatMap)
  dataset.withColumn($(outputCol), flatMapUdf(col($(inputCol))))
}
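To see how `explode` changes both the row count and the column type, here is a sketch that applies the same splitting UDF with and without `explode`. It assumes `spark-sql` on the classpath and a local `SparkSession`; the object `ExplodeVsArray` and the toy data are hypothetical.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode, udf}
import org.apache.spark.sql.types.DataType

// Hypothetical demo: row count and type of the output column with and without explode().
object ExplodeVsArray {
  def run(): ((Long, DataType), (Long, DataType)) = {
    val spark = SparkSession.builder().master("local[1]").appName("explode-vs-array").getOrCreate()
    import spark.implicits._

    val df = Seq("a,b", "c").toDF("tags")
    // Same splitting logic as the question's UDF (nulls are not handled in this sketch).
    val splitFn: String => Seq[String] = s => s.split(",").toSeq
    val splitUdf = udf(splitFn)

    // Without explode: one row per input row, column type array<string>.
    val asArray = df.withColumn("t", splitUdf(col("tags")))
    // With explode: one row per array element, column type string.
    val exploded = df.withColumn("t", explode(splitUdf(col("tags"))))

    ((asArray.count(), asArray.schema("t").dataType),
     (exploded.count(), exploded.schema("t").dataType))
  }

  def main(args: Array[String]): Unit = {
    val ((arrCount, arrType), (expCount, expType)) = run()
    println(s"without explode: $arrCount rows, $arrType")
    println(s"with explode:    $expCount rows, $expType")
  }
}
```

Only the array-typed column is a valid `CountVectorizer` input; the exploded one is a plain string column with more rows than the original data.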

Modify the `transformSchema` method to output `ArrayType(StringType)`:
```
override def transformSchema(schema: StructType): StructType = {
  val dataType = schema($(inputCol)).dataType
  require(
    dataType.isInstanceOf[StringType],
    s"Input column must be of type StringType but got ${dataType}")
  val inputFields = schema.fields
  require(
    !inputFields.exists(_.name == $(outputCol)),
    s"Output column ${$(outputCol)} already exists.")

  // Keep all existing fields and append the array<string> output column
  schema.add($(outputCol), ArrayType(StringType))
}
```
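Putting those fixes together, a complete transformer could look roughly like the sketch below: no `explode`, an `array<string>` output, all input columns preserved, the existence check first, and `copy` returning the right type. It assumes `spark-sql` and `spark-mllib` on the classpath and is a sketch under those assumptions, not the answerer's exact code; the class name is kept from the question even though it no longer flattens rows.

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

// Sketch of the corrected transformer: splits a comma-separated string column
// into an array<string> column and keeps every other column.
class FlatMapTransformer(override val uid: String) extends Transformer {
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final val outputCol = new Param[String](this, "outputCol", "The output column")

  def getInputCol: String = $(inputCol)
  def getOutputCol: String = $(outputCol)
  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("FlatMapTransformer"))

  // "a,b,c" -> Seq("a", "b", "c"); null inputs are not handled in this sketch.
  private val splitFn: String => Seq[String] = _.split(",").toSeq

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema) // fail fast with the checks below
    // Assign the UDF to a val first: udf(...)(col) would be parsed as passing
    // the column to udf's implicit TypeTag parameter list.
    val splitUdf = udf(splitFn)
    dataset.withColumn($(outputCol), splitUdf(col($(inputCol))))
  }

  override def transformSchema(schema: StructType): StructType = {
    require(schema.fieldNames.contains($(inputCol)), s"Input column ${$(inputCol)} does not exist.")
    val dataType = schema($(inputCol)).dataType
    require(dataType.isInstanceOf[StringType], s"Input column must be of type StringType but got $dataType")
    require(!schema.fieldNames.contains($(outputCol)), s"Output column ${$(outputCol)} already exists.")
    // Keep every existing field and append the array<string> output column.
    schema.add($(outputCol), ArrayType(StringType))
  }

  override def copy(extra: ParamMap): FlatMapTransformer = defaultCopy(extra)
}
```

Its output column (for example `feature + "_transformed"`) can then be fed directly to `CountVectorizer.setInputCol`.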
Change the vector assembler to:

val featureAssembler = new VectorAssembler()
  .setInputCols(Array("cat_features", "num_features", "cat_ohe_features"))
  .setOutputCol("features")

I tried your pipeline on a dummy DataFrame and it worked fine. For the complete code, see this gist.
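The gist itself is not reproduced here. As an independent, hedged sketch of the same idea, the pipeline below runs split, `CountVectorizer` and `VectorAssembler` on a hypothetical dummy DataFrame. To stay self-contained it swaps in Spark's built-in `RegexTokenizer` (pattern `","`) for the custom transformer; like the corrected transformer it emits `array<string>`, but note that it also lowercases tokens by default. Assumes `spark-sql` and `spark-mllib` on the classpath.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, RegexTokenizer, VectorAssembler}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.SparkSession

// Hypothetical end-to-end check on a dummy DataFrame (not the gist's code).
object DummyPipeline {
  def run(): Seq[Vector] = {
    val spark = SparkSession.builder().master("local[1]").appName("dummy-pipeline").getOrCreate()
    import spark.implicits._

    val train = Seq((1.0, "a,b"), (2.0, "b,c"), (3.0, "a")).toDF("num", "tags")

    // Built-in stand-in for the custom transformer: splits on "," into array<string>.
    val splitter = new RegexTokenizer()
      .setInputCol("tags")
      .setOutputCol("tags_transformed")
      .setPattern(",")

    val countVectorizer = new CountVectorizer()
      .setInputCol("tags_transformed")
      .setOutputCol("tags_vectorized")
      .setVocabSize(10)

    // Every assembler input is either in `train` or produced by an earlier stage.
    val assembler = new VectorAssembler()
      .setInputCols(Array("num", "tags_vectorized"))
      .setOutputCol("features")

    val model = new Pipeline().setStages(Array(splitter, countVectorizer, assembler)).fit(train)
    model.transform(train).select("features").collect().map(_.getAs[Vector](0)).toSeq
  }

  def main(args: Array[String]): Unit = run().foreach(println)
}
```

With a vocabulary of three distinct tokens, each `features` vector should hold one numeric value plus three term counts. If any assembler input name does not match a column produced upstream (as with `cat_ohe_features_vectorized` in the question), `fit` fails during schema validation before any data is processed.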
