na.fill in a model pipeline

jhkqcmku posted on 2021-07-09 in Spark

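I'm trying to fill null values with 0 inside a pipeline and then export the pipeline to a PMML file.
My first attempt was to write a custom Transformer, but I got an error saying the 'impute_to_zero' object has no attribute '_to_java'. After some research here, it looks like I need to write my own `_to_java` method, but I'm struggling to do that in code.
Here is my code: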

from pyspark.sql import DataFrame
from pyspark.ml import Transformer

class impute_to_zero(Transformer):
    """
    A custom Transformer that replaces nulls in all numeric columns with 0
    """

    def __init__(self) -> None:
        super(impute_to_zero, self).__init__()

    def _transform(self, df: DataFrame) -> DataFrame:
        # na.fill(0) only affects numeric columns; string columns are left unchanged
        return df.na.fill(0)

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import StringIndexer, VectorAssembler, SQLTransformer
from pyspark2pmml import PMMLBuilder, toPMMLBytes

# Prepare training data from a list of (id, category, numcol, label) tuples.
# `spark` (SparkSession) and `sc` (SparkContext) come from the Databricks notebook.

training = spark.createDataFrame([
    (0, "abc", 3, 1.0),
    (1, "b", None, 0.0),
    (2, "spark", 8, 1.0),
    (3, "hadoop", 4, 0.0)
], ["id", "category", "numcol", "label"])

fillna = impute_to_zero()
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
assembler = VectorAssembler(inputCols=["categoryIndex", "numcol"], outputCol="features")
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=5, maxDepth=3)
pipeline = Pipeline(stages=[fillna, indexer, assembler, rf])

model = pipeline.fit(training)

pmmlBuilder = PMMLBuilder(sc, training, model)
pmmlBuilder.buildFile("/dbfs/tmp/test.pmml")
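The `_to_java` error suggests that the export path converts each pipeline stage to its JVM counterpart, which a pure-Python Transformer cannot provide. A small hypothetical helper (not part of pyspark or pyspark2pmml) can list which stages lack that method before calling `PMMLBuilder`; it is only a duck-typing check under that assumption.

```python
from typing import Iterable, List


def stages_without_java(stages: Iterable[object]) -> List[str]:
    """Return the class names of stages that have no `_to_java` method.

    Built-in pyspark.ml stages (StringIndexer, VectorAssembler, ...) wrap a
    JVM object and define `_to_java`; a Python-only Transformer such as
    impute_to_zero does not, so it would show up in this list.
    """
    return [type(stage).__name__ for stage in stages if not hasattr(stage, "_to_java")]
```

Used as `stages_without_java(pipeline.getStages())`, it would be expected to report only `impute_to_zero` for the pipeline above.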

My second attempt used a SQLTransformer, but pyspark2pmml seems to have a problem with it. I get the error `IllegalArgumentException: Name(s) [numcol] do not match any fields`.

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import StringIndexer, VectorAssembler, SQLTransformer
from pyspark2pmml import PMMLBuilder, toPMMLBytes

# Prepare training data from a list of (id, category, numcol, label) tuples.

training = spark.createDataFrame([
    (0, "abc", 3, 1.0),
    (1, "b", None, 0.0),
    (2, "spark", 8, 1.0),
    (3, "hadoop", 4, 0.0)
], ["id", "category", "numcol", "label"])

# Replace nulls in numcol with 0, keeping the same column name
fillna = SQLTransformer(statement="""
select
  category,
  case when numcol is null then 0 else numcol end as numcol,
  label
from __THIS__
""")
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
assembler = VectorAssembler(inputCols=["categoryIndex", "numcol"], outputCol="features")
rf = RandomForestClassifier(labelCol="label", featuresCol="features", numTrees=5, maxDepth=3)
pipeline = Pipeline(stages=[fillna, indexer, assembler, rf])

# Fit the pipeline to the training data.

model = pipeline.fit(training)

pmmlBuilder = PMMLBuilder(sc, training, model)
pmmlBuilder.buildFile("/dbfs/tmp/test.pmml")
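The statement above writes its output back under the input name `numcol`, and the error names exactly that column. One untested variation is to give the filled column a new name and keep the original columns with `*`. The sketch below is a hypothetical helper (`fill_zero_statement` and the `_filled` suffix are assumptions, not pyspark or pyspark2pmml API) that builds such a statement using Spark SQL's `COALESCE`; whether pyspark2pmml accepts this form is likewise an assumption.

```python
from typing import Iterable


def fill_zero_statement(columns: Iterable[str], suffix: str = "_filled") -> str:
    """Build a SQLTransformer statement that replaces nulls with 0.

    Each filled column is emitted under a new name (column + suffix) rather
    than overwriting the input column; `SELECT *` keeps all original columns.
    Assumes the column names are plain identifiers that need no quoting.
    """
    cols = list(columns)
    if not cols:
        raise ValueError("need at least one column to fill")
    filled = [f"COALESCE({col}, 0) AS {col}{suffix}" for col in cols]
    return "SELECT *, " + ", ".join(filled) + " FROM __THIS__"
```

In the pipeline this would become `SQLTransformer(statement=fill_zero_statement(["numcol"]))`, with the assembler reading `inputCols=["categoryIndex", "numcol_filled"]`.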
