pyspark交叉验证中的问题

3phpmpom  于 2021-07-13  发布在  Spark
关注(0)|答案(1)|浏览(487)

我试图在下面的代码中交叉验证pyspark上的rf模型,并抛出错误:

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Your code

trainData = raw_data_ 
numFolds = 5 

rf = RandomForestClassifier(labelCol="Target", featuresCol="Scaled_features")
evaluator = MulticlassClassificationEvaluator() #    

pipeline = Pipeline(stages=[rf])
paramGrid = (ParamGridBuilder()\
    .addGrid(rf.numTrees, [3, 10])\
    .build())
crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=numFolds)

tr_model = crossval.fit(trainData)

但这会导致一个错误

我的原始数据变量是:

|            features|Position_Group|     Scaled_features|Target|
+--------------------+--------------+--------------------+------+
|[173.735992431640...|           FWD|[12.9261366722264...|     0|
|[188.975997924804...|           FWD|[14.0600087682323...|     0|
|[179.832000732421...|           FWD|[13.3796859647366...|     0|
|[155.752807617187...|           MID|[11.5881692110224...|     2|
|[176.783996582031...|           FWD|[13.1529113184815...|     0|
|[176.783996582031...|           MID|[13.1529113184815...|     2|
|[182.880004882812...|           FWD|[13.6064606109917...|     0|
|[182.880004882812...|           DEF|[13.6064606109917...|     1|
|[182.880004882812...|           FWD|[13.6064606109917...|     0|
|[182.880004882812...|           MID|[13.6064606109917...|     2|
|[188.975997924804...|           DEF|[14.0600087682323...|     1|
|[176.783996582031...|           MID|[13.1529113184815...|     2|
|[170.688003540039...|           MID|[12.6993631612409...|     2|
|[155.447998046875...|           FWD|[11.5654910652351...|     0|
|[188.975997924804...|           FWD|[14.0600087682323...|     0|
|[179.832000732421...|           MID|[13.3796859647366...|     2|
|[188.975997924804...|           MID|[14.0600087682323...|     2|
|[185.927993774414...|           FWD|[13.8332341219772...|     0|
|[176.783996582031...|           FWD|[13.1529113184815...|     0|
|[188.975997924804...|           DEF|[14.0600087682323...|     1|
+--------------------+--------------+--------------------+------+

有什么关于为什么和在哪里发生这个问题的建议吗?如何解决这个问题?
谢谢

qv7cva1a

qv7cva1a1#

错误显示
调用evaluate时出错。字段“label”不存在。
这说明评估者出了问题。在求值器的定义中,您没有指定label列,因此求值器尝试使用默认的“label”列,但该列不存在。
要解决这个问题,需要在示例化求值器时指定label列,就像对分类器所做的那样。例如

evaluator = MulticlassClassificationEvaluator(labelCol="Target")

相关问题