
apache-spark - How to cross-validate a RandomForest model?


I want to evaluate a random forest being trained on some data. Is there any utility in Apache Spark to do this, or do I have to perform cross-validation manually?

Best Answer

ML provides the CrossValidator class, which can be used to perform cross-validation and parameter search. Assuming your data is already preprocessed, you can add cross-validation as follows:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// [label: double, features: vector]
val trainingData: org.apache.spark.sql.DataFrame = ???
val nFolds: Int = ???
val numTrees: Int = ???
val metric: String = ???

val rf = new RandomForestClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setNumTrees(numTrees)

val pipeline = new Pipeline().setStages(Array(rf))

val paramGrid = new ParamGridBuilder().build() // No parameter search
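// To search over parameters as well, build a non-empty grid instead;
// a sketch, with illustrative values only:
// val paramGrid = new ParamGridBuilder()
// .addGrid(rf.numTrees, Array(3, 10))
// .addGrid(rf.maxDepth, Array(5, 10))
// .build()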

val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
// "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
.setMetricName(metric)

val cv = new CrossValidator()
// ml.Pipeline with ml.classification.RandomForestClassifier
.setEstimator(pipeline)
// ml.evaluation.MulticlassClassificationEvaluator
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(nFolds)

val model = cv.fit(trainingData) // trainingData: DataFrame
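
If you want to inspect the result, the returned CrossValidatorModel exposes the average metric for each parameter map and the best fitted pipeline. A minimal sketch (testData is an assumed DataFrame with the same label/features columns):

import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.RandomForestClassificationModel

// Average cross-validation metric for each ParamMap in the grid
println(model.avgMetrics.mkString(", "))

// The best pipeline, refit on the whole training set
val bestRf = model.bestModel
.asInstanceOf[PipelineModel]
.stages(0)
.asInstanceOf[RandomForestClassificationModel]

// transform uses the best model under the hood
val predictions = model.transform(testData) // testData: DataFrame (assumed)
val score = evaluator.evaluate(predictions)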

Using PySpark:

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

trainingData = ... # DataFrame[label: double, features: vector]
numFolds = ... # Integer

rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator() # + other params as in Scala

pipeline = Pipeline(stages=[rf])
paramGrid = (ParamGridBuilder()
.addGrid(rf.numTrees, [3, 10])
.addGrid(...) # Add other parameters
.build())

crossval = CrossValidator(
estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=evaluator,
numFolds=numFolds)

model = crossval.fit(trainingData)
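
As in the Scala version, the fitted CrossValidatorModel exposes avgMetrics and bestModel, and model.transform(...) scores new data with the best model found.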

Regarding apache-spark - How to cross-validate a RandomForest model?, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/32769573/
