如何从pyspark管道(stringindexer->onehotencoder->vectorassembler)获取原始的分类变量值?

axzmvihb  于 2021-05-17  发布在  Spark
关注(0)|答案(0)|浏览(223)

我的分类变量中有很多不同的级别,我想知道我的模型指向哪个值:

indexer = StringIndexer(inputCols=["a","b"], outputCols=["a_ind", "b_ind"])
df = indexer.setHandleInvalid("keep").fit(df).transform(df)

encoder = OneHotEncoder(inputCols=["a_ind", "b_ind"],
                        outputCols=["a_vec", "b_vec"])
df = encoder.fit(df).transform(df)

df = VectorAssembler(inputCols=["a_vec","b_vec"],
                     outputCol="features").transform(df)

“功能”列如下所示:

0: 0
1: 45568
2: 
0: 1
1: 2923
3: 
0: 1
1: 1

当我打印决策树模型时,我看到下面的树

print(dtModel.toDebugString)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_e9f4ea2ba51e, depth=5, numNodes=13, numClasses=2, numFeatures=45568
  If (feature 0 in {1.0})
   Predict: 0.0
  Else (feature 0 not in {1.0})
   If (feature 1 in {1.0})
    Predict: 0.0
   Else (feature 1 not in {1.0})
    If (feature 3 in {1.0})
     Predict: 0.0
    Else (feature 3 not in {1.0})
     If (feature 5 in {1.0})
      If (feature 3664 in {1.0})
       Predict: 1.0
      Else (feature 3664 not in {1.0})
       Predict: 0.0
     Else (feature 5 not in {1.0})
      If (feature 2933 in {1.0})
       Predict: 1.0
      Else (feature 2933 not in {1.0})
       Predict: 0.0

我希望能够使用此树来知道哪些值触发1.0的预测(功能部件3664、功能部件2993、功能部件5、功能部件3等)。我怎样才能回到a列或b列中的原始字符串?

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题