
scala - How do I access the individual trees in a model created by RandomForestClassifier (spark.ml version)?

Reposted. Author: 行者123. Updated: 2023-12-05 00:19:15

How can I access the individual trees in a model produced by Spark ML's RandomForestClassifier? I am using the Scala version of RandomForestClassifier.

Best answer

The fitted model actually has a trees attribute:

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.{
  RandomForestClassificationModel, RandomForestClassifier,
  DecisionTreeClassificationModel
}

// Attach nominal metadata so the classifier knows the label has two classes
val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0", "1.0")
  .toMetadata

// Assumes a spark-shell session where sqlContext and its implicits (for $"label") are in scope
val data = sqlContext.read.format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt")
  .withColumn("label", $"label".as("label", meta))

val rf: RandomForestClassifier = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")

// Narrow each tree to its concrete type so that transform is available
val trees: Array[DecisionTreeClassificationModel] = rf.fit(data).trees.collect {
  case t: DecisionTreeClassificationModel => t
}

As you can see, the only real issue is getting the types right so that the trees can actually be used:

trees.head.transform(data).show(3)
// +-----+--------------------+-------------+-----------+----------+
// |label| features|rawPrediction|probability|prediction|
// +-----+--------------------+-------------+-----------+----------+
// | 0.0|(692,[127,128,129...| [33.0,0.0]| [1.0,0.0]| 0.0|
// | 1.0|(692,[158,159,160...| [0.0,59.0]| [0.0,1.0]| 1.0|
// | 1.0|(692,[124,125,126...| [0.0,59.0]| [0.0,1.0]| 1.0|
// +-----+--------------------+-------------+-----------+----------+
// only showing top 3 rows
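
Beyond calling transform, each tree can also be inspected directly. The following is a minimal sketch rather than part of the original answer. It assumes Spark 2.x or later, where a SparkSession is available, trees is already typed as Array[DecisionTreeClassificationModel], and the number of classes can be inferred from the data without label metadata. The toy data set and the names session, toy, and forest are hypothetical.

```scala
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Assumption: Spark 2.x or later, running in local mode.
val session = SparkSession.builder()
  .master("local[*]")
  .appName("rf-trees-sketch")
  .getOrCreate()

// Tiny hypothetical data set, for illustration only.
val toy = session.createDataFrame(Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (0.0, Vectors.dense(0.1, 0.9)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (1.0, Vectors.dense(0.9, 0.2))
)).toDF("label", "features")

val forest = new RandomForestClassifier()
  .setNumTrees(3)
  .setSeed(42L)
  .fit(toy)

// Print a one-line summary per tree, paired with its weight in the ensemble.
forest.trees.zip(forest.treeWeights).zipWithIndex.foreach {
  case ((tree, weight), i) =>
    println(s"tree $i: depth=${tree.depth}, nodes=${tree.numNodes}, weight=$weight")
}

// Print the full split structure of the first tree.
println(forest.trees.head.toDebugString)
```

The exact depths and node counts depend on the bootstrap sample each tree was trained on, so no particular output is shown here.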

Note:

If you work with a Pipeline, you can extract the individual trees as well:

import org.apache.spark.ml.Pipeline

val model = new Pipeline().setStages(Array(rf)).fit(data)

// There is only one stage and we know its type,
// but let's be thorough
val rfModelOption = model.stages.headOption match {
  case Some(m: RandomForestClassificationModel) => Some(m)
  case _ => None
}

val trees = rfModelOption.map {
  _.trees // ... as before
}.getOrElse(Array())
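
When the pipeline has more than one stage, the forest is not necessarily the first stage. One possible approach, sketched here and not taken from the original answer, is to search the fitted stages by type with collectFirst. It assumes Spark 2.x or later. The VectorAssembler stage, the toy columns x1 and x2, and the names localSession, raw, fitted, and pipelineTrees are hypothetical.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.{
  DecisionTreeClassificationModel, RandomForestClassificationModel,
  RandomForestClassifier
}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

// Assumption: Spark 2.x or later, running in local mode.
val localSession = SparkSession.builder()
  .master("local[*]")
  .appName("rf-pipeline-sketch")
  .getOrCreate()

// Tiny hypothetical data set with two numeric feature columns.
val raw = localSession.createDataFrame(Seq(
  (0.0, 0.0, 1.0),
  (0.0, 0.1, 0.9),
  (1.0, 1.0, 0.0),
  (1.0, 0.9, 0.2)
)).toDF("label", "x1", "x2")

val assembler = new VectorAssembler()
  .setInputCols(Array("x1", "x2"))
  .setOutputCol("features")

val rfStage = new RandomForestClassifier().setNumTrees(3).setSeed(42L)

val fitted: PipelineModel = new Pipeline()
  .setStages(Array[PipelineStage](assembler, rfStage))
  .fit(raw)

// Find the first fitted stage that is a random forest model, wherever it sits.
val pipelineTrees: Array[DecisionTreeClassificationModel] = fitted.stages
  .collectFirst { case m: RandomForestClassificationModel => m.trees }
  .getOrElse(Array.empty[DecisionTreeClassificationModel])

println(s"extracted ${pipelineTrees.length} trees from the pipeline")
```

Falling back to a typed empty array keeps the result type stable even when the pipeline contains no random forest stage.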

This question and answer about scala - how to access the individual trees in a model created by RandomForestClassifier (spark.ml version) come from a similar question on Stack Overflow: https://stackoverflow.com/questions/36314436/
