scala - RandomForestClassifier was given input with invalid label column error in Apache Spark


I am trying to measure the accuracy of a random forest classifier model in Scala using 5-fold cross-validation, but I get the following error at runtime:

java.lang.IllegalArgumentException: RandomForestClassifier was given input with invalid label column label, without the number of classes specified. See StringIndexer.

The error is thrown at this line ---> val cvModel = cv.fit(trainingData)

My code for cross-validating the dataset with a random forest is as follows:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.textFile("exprogram/dataset.txt")
val parsedData = data.map { line =>
  val parts = line.split(',')
  // Column 41 holds the label; the first 41 columns are the features.
  LabeledPoint(parts(41).toDouble,
    Vectors.dense(parts.take(41).map(_.toDouble)))
}


val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)

val trainingData = training.toDF()

val testData = test.toDF()

val nFolds: Int = 5
val NumTrees: Int = 5

val rf = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(NumTrees)

val pipeline = new Pipeline()
  .setStages(Array(rf))

val paramGrid = new ParamGridBuilder()
  .build()

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("precision") // removed in Spark 2.x; use "accuracy" there

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(nFolds)

val cvModel = cv.fit(trainingData)

val results = cvModel.transform(testData)
  .select("label", "prediction")
  .collect()

val numCorrectPredictions = results.map(row =>
  if (row.getDouble(0) == row.getDouble(1)) 1 else 0).foldLeft(0)(_ + _)
val accuracy = 1.0D * numCorrectPredictions / results.size

println("Test set accuracy: %.3f".format(accuracy))

Can anyone explain what is wrong in the code above?

Accepted Answer

RandomForestClassifier, like many other ML algorithms, requires specific metadata to be set on the label column, and the label values must be integers from [0, 1, 2, ..., #classes) represented as doubles. Typically this is handled by upstream Transformers such as StringIndexer. Because you convert the labels manually, the metadata fields are not set and the classifier cannot confirm that these requirements are satisfied.

import org.apache.spark.mllib.linalg.Vectors
import sqlContext.implicits._ // required for toDF (pre-imported in spark-shell)

val df = Seq(
  (0.0, Vectors.dense(1, 0, 0, 0)),
  (1.0, Vectors.dense(0, 1, 0, 0)),
  (2.0, Vectors.dense(0, 0, 1, 0)),
  (2.0, Vectors.dense(0, 0, 0, 1))
).toDF("label", "features")

val rf = new RandomForestClassifier()
  .setFeaturesCol("features")
  .setNumTrees(5)

rf.setLabelCol("label").fit(df)
// java.lang.IllegalArgumentException: RandomForestClassifier was given input ...

You can re-encode the label column using StringIndexer:

import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("label_idx")
  .fit(df)

rf.setLabelCol("label_idx").fit(indexer.transform(df))
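Because the indexed column no longer matches the raw label values, a common follow-up is to map predictions back to the original labels with IndexToString. A minimal sketch, reusing the rf and indexer defined above:

import org.apache.spark.ml.feature.IndexToString

val indexed = indexer.transform(df)
val model = rf.setLabelCol("label_idx").fit(indexed)

// Translate the prediction indices back into the original label values.
val converter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predicted_label")
  .setLabels(indexer.labels)

converter.transform(model.transform(indexed)).show()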

or set the required metadata manually:

import org.apache.spark.ml.attribute.NominalAttribute

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0", "1.0", "2.0")
  .toMetadata

rf.setLabelCol("label_meta").fit(
  df.withColumn("label_meta", $"label".as("", meta))
)
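To confirm the metadata actually landed on the column, you can read the attribute back from the schema. A quick sanity check, assuming the df and meta from above:

import org.apache.spark.ml.attribute.Attribute

val withMeta = df.withColumn("label_meta", $"label".as("", meta))
// Recovers the NominalAttribute (name "label", values 0.0, 1.0, 2.0)
// if the metadata was attached correctly.
Attribute.fromStructField(withMeta.schema("label_meta"))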

Note:

Labels created with StringIndexer depend on frequency, not value:

indexer.labels
// Array[String] = Array(2.0, 0.0, 1.0)
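If you need a mapping ordered by value rather than frequency, newer Spark versions (2.3 and later; verify against your version) let you override the ordering. A sketch:

// Order labels alphabetically by value instead of the default "frequencyDesc".
val byValue = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("label_idx")
  .setStringOrderType("alphabetAsc")
  .fit(df)

byValue.labels
// Array[String] = Array(0.0, 1.0, 2.0)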

PySpark:

In Python, the metadata fields can be set directly on the schema:

from pyspark.sql.types import StructField, DoubleType

StructField(
    "label", DoubleType(), False,
    {"ml_attr": {
        "name": "label",
        "type": "nominal",
        "vals": ["0.0", "1.0", "2.0"]
    }}
)
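Alternatively, the same dictionary can be attached to an existing column. A sketch assuming Spark 2.2+, where Column.alias accepts a metadata keyword argument, and a DataFrame df with a "label" column:

from pyspark.sql.functions import col

meta = {"ml_attr": {"name": "label",
                    "type": "nominal",
                    "vals": ["0.0", "1.0", "2.0"]}}

# Re-alias the column with the ml_attr metadata attached.
df_with_meta = df.withColumn("label", col("label").alias("label", metadata=meta))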

Regarding "scala - RandomForestClassifier was given input with invalid label column error in Apache Spark", see the similar question on Stack Overflow: https://stackoverflow.com/questions/36517302/
