scala - Why does my DecisionTreeClassifier model complain that the labelCol doesn't exist when predicting?


I'm starting out writing an ML model that classifies paragraphs within a set of documents. I wrote my model and the results looked great! However, when I tried to feed in a CSV that doesn't contain the labelCol (i.e. the tag column, the column I'm trying to predict), it threw the error: "Field tagIndexed does not exist."

This is strange. The column I'm trying to predict is "tag", so why would it expect a "tagIndexed" column when I call model.transform(df) (in Predict.scala)? I'm inexperienced with ML, but as far as I can tell a DecisionTreeClassifier shouldn't need the labelCol to be present in the test data. What am I missing here?

I create the model, validate it against test data, and save it to disk. Then, in a separate Scala object, I load the model and pass the CSV into it.

//Train.scala    
package com.secret.classifier
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.Column
import org.apache.spark.ml.feature.{HashingTF, IDF, StringIndexer, Tokenizer, VectorAssembler, Word2Vec}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

...

// Categorical columns to index; note that "tag", the label, is among them
val colSeq = Seq("font", "tag")
val indexSeq = colSeq.map(col => new StringIndexer().setInputCol(col).setOutputCol(col + "Indexed").fit(dfNoNan))

val tokenizer = new Tokenizer().setInputCol("soup").setOutputCol("words")
//val wordsData = tokenizer.transform(dfNoNan)

val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("rawFeatures")
  .setNumFeatures(20)

val featuresCol = "features"
val assembler = new VectorAssembler()
  .setInputCols((numericCols ++ colSeq.map(_ + "Indexed")).toArray)
  .setOutputCol(featuresCol)

val labelCol = "tagIndexed"
val decisionTree = new DecisionTreeClassifier()
  .setLabelCol(labelCol)
  .setFeaturesCol(featuresCol)

val pipeline = new Pipeline().setStages((indexSeq :+ tokenizer :+ hashingTF :+ assembler :+ decisionTree).toArray)

val Array(training, test) = dfNoNan.randomSplit(Array(0.8, 0.2), seed=420420)

val model = pipeline.fit(training)
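
// Validation step (not shown in the original excerpt). A minimal sketch,
// assuming accuracy is measured with a MulticlassClassificationEvaluator
// on the held-out "test" split:
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol(labelCol)
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
println(s"Test accuracy: ${evaluator.evaluate(model.transform(test))}")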


model.write.overwrite().save("tmp/spark-model")

//Predict.scala
package com.secret.classifier
import org.apache.spark.sql.functions._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.Column
import org.apache.spark.ml.feature.{HashingTF, IDF, StringIndexer, Tokenizer, VectorAssembler, Word2Vec}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.types
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

...

val dfImport = spark.read
  .format("csv")
  .option("header", "true")
  //.option("mode", "DROPMALFORMED")
  .schema(customSchema)
  .load(csvLocation)

val df = dfImport.drop("_c0", "doc_name")
df.show(20)

val model = PipelineModel.load("tmp/spark-model")

val predictions = model.transform(df)

predictions.show(20)


//pom.xml -> Spark/Scala specific dependencies
<properties>
  <maven.compiler.source>1.8</maven.compiler.source>
  <maven.compiler.target>1.8</maven.compiler.target>
  <encoding>UTF-8</encoding>
  <scala.version>2.11.12</scala.version>
  <scala.compat.version>2.11</scala.compat.version>
  <spec2.version>4.2.0</spec2.version>
</properties>

<dependencies>
  <dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.11</artifactId>
    <version>2.3.1</version>
  </dependency>

  <!-- https://mvnrepository.com/artifact/com.databricks/spark-csv -->
  <dependency>
    <groupId>com.databricks</groupId>
    <artifactId>spark-csv_2.11</artifactId>
    <version>1.5.0</version>
  </dependency>

  <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql -->
  <dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.11</artifactId>
    <version>2.3.1</version>
  </dependency>

  <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
  <dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.11</artifactId>
    <version>2.3.1</version>
  </dependency>

  <dependency>
    <groupId>com.univocity</groupId>
    <artifactId>univocity-parsers</artifactId>
    <version>2.8.0</version>
  </dependency>

  <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib -->
  <dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.11</artifactId>
    <version>2.3.1</version>
  </dependency>
</dependencies>

The expected result is that the prediction model doesn't throw an error. Instead, it throws the error "Field "tagIndexed" does not exist."

Best Answer

It looks like you've included the label field in your features, since it's among the colSeq column outputs. In this step you want to include only the feature columns:

.setInputCols((numericCols ++ colSeq.map(_+"Indexed")).toArray)

I've found the .filterNot() function helpful here.
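
As a minimal sketch of that fix (reusing only the names already defined in Train.scala, such as numericCols and colSeq), the label's indexed column can be filtered out of the assembler's inputs:

// Sketch: keep "tagIndexed" out of the feature vector so the fitted
// pipeline no longer requires it at prediction time.
val labelCol = "tagIndexed"
val featureCols = (numericCols ++ colSeq.map(_ + "Indexed")).filterNot(_ == labelCol)

val assembler = new VectorAssembler()
  .setInputCols(featureCols.toArray)
  .setOutputCol("features")

With the label kept out of the feature vector, the only stage that still references "tag" is its StringIndexerModel, and in Spark 2.x a StringIndexerModel whose input column is missing skips itself at transform time (with a log message), so model.transform(df) can then run on the unlabeled CSV.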

Regarding "scala - Why does my DecisionTreeClassifier model complain that the labelCol doesn't exist when predicting?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/54753097/
