gpt4 book ai didi

scala - Spark ML 将预测标签转换为字符串,无需训练 DataFrame

转载 作者:行者123 更新时间:2023-11-30 08:39:09 25 4
gpt4 key购买 nike

我在 Apache Spark ML(版本 2.1.0)中使用 NaiveBayes 多项式分类器来预测一些文本类别。

问题是如何在没有经过训练的 DataFrame 的情况下将预测标签(0.0、1.0、2.0)转换为字符串。

我知道可以使用 IndexToString,但只有在训练和预测同时进行时才有用。但是,就我而言,它是独立的工作。

代码看起来像
1) TrainingModel.scala : 训练模型并将模型保存在文件中。
2) CategoryPrediction.scala:从文件加载训练好的模型并对测试数据进行预测。

请提出解决方案:

TrainingModel.scala

val trainData: Dataset[LabeledRecord] = spark.read.option("inferSchema", "false")
.schema(schema).csv("trainingdata1.csv").as[LabeledRecord]

val labelIndexer = new StringIndexer().setInputCol("category").setOutputCol("label").fit(trainData).setHandleInvalid("skip")

val tokenizer = new RegexTokenizer().setInputCol("text").setOutputCol("words")

val hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("features")
.setNumFeatures(1000)

val rf = new NaiveBayes().setLabelCol("label").setFeaturesCol("features").setModelType("multinomial")

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, labelIndexer, rf))

val model = pipeline.fit(trainData)

model.write.overwrite().save("naivebayesmodel");

CategoryPrediction.scala

val testData: Dataset[PredictLabeledRecord] = spark.read.option("inferSchema", "false")
.schema(predictSchema).csv("testingdata.csv").as[PredictLabeledRecord]

val model = PipelineModel.load("naivebayesmodel")

val predictions = model.transform(testData)

// val labelConverter = new IndexToString()
// .setInputCol("prediction")
// .setOutputCol("predictedLabelString")
// .setLabels(trainDataFrameIndexer.labels)

predictions.select("prediction", "text").show(false)

trainingdata1.csv

category,text
Drama,"a b c d e spark"
Action,"b d"
Horror,"spark f g h"
Thriller,"hadoop mapreduce"

testingdata.csv

text
"a b c d e spark"
"spark f g h"

最佳答案

添加一个转换器,将预测类别转换回管道中的标签,如下所示:

val categoryConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("category")
.setLabels(labelIndexer.labels)

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, labelIndexer, rf, categoryConverter))

这将进行预测并使用 labelIndexer 将其转换回标签。

关于scala - Spark ML 将预测标签转换为字符串,无需训练 DataFrame,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44997233/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com