如何加载训练过的randomforestclassificationmodel?

fcipmucu  于 2021-07-03  发布在  Java
关注(0)|答案(1)|浏览(426)

我训练并测试了一个ml模型(gbtclassionmodel或randomforestclassificationmodel)。然后我想保存经过训练的模型以备将来使用。所以我做了以下工作:

model.save("...");

以gbtclassificationmodel为例,保存之后。保存的文件是一个包含“数据、元数据和树元数据”的目录。我的问题是如何使用这个保存的模型以备将来使用?例如,我想做如下操作:

model = spark.load("...");
 Dataset<Row> predict_data= model_model.transform(dataset_test1)

有什么建议吗?谢谢您。
更新:
结果很简单:

GBTClassificationModel model1 = GBTClassificationModel.load("...");
 Dataset<Row> predict_data= model1.transform(dataset_test)
h7appiyu

h7appiyu1#

应该使用randomforestclassificationmodel.load方法。
load(path:string):randomforestclassificationmodel从输入路径读取一个ml示例,这是 read.load(path) .
在scala中,在您的情况下,它将如下所示:

import org.apache.spark.ml.classification.RandomForestClassificationModel
val model = RandomForestClassificationModel.load("/analytics_shared/qoe/km_model")

我强烈建议使用spark mllib的ml管道功能:
ml管道提供了一组统一的、构建在Dataframe之上的高级api,帮助用户创建和调优实用的机器学习管道。
使用ml管道,您只需简单地替换 RandomForestClassificationModel 管道模型。

import org.apache.spark.ml.PipelineModel
val model = PipelineModel.load("...")

相关问题