
java - Spark 2.1.0 - SparkML requirement failed

Reposted. Author: 行者123. Updated: 2023-12-02 12:28:14

I am working with the Spark 2.1.0 KMeans clustering algorithm.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class ClusteringTest {
    public static void main(String[] args) {
        SparkSession session = SparkSession.builder()
                .appName("Clustering Test")
                .config("spark.master", "local")
                .getOrCreate();
        session.sparkContext().setLogLevel("ERROR");

        List<Row> rawDataTraining = Arrays.asList(
                RowFactory.create(1.0, Vectors.dense(1.0, 1.0, 1.0).toSparse()),
                RowFactory.create(1.0, Vectors.dense(2.0, 2.0, 2.0).toSparse()),
                RowFactory.create(1.0, Vectors.dense(3.0, 3.0, 3.0).toSparse()),

                RowFactory.create(2.0, Vectors.dense(6.0, 6.0, 6.0).toSparse()),
                RowFactory.create(2.0, Vectors.dense(7.0, 7.0, 7.0).toSparse()),
                RowFactory.create(2.0, Vectors.dense(8.0, 8.0, 8.0).toSparse())
                //...
        );

        StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty())
        });

        Dataset<Row> myRawData = session.createDataFrame(rawDataTraining, schema);
        Dataset<Row>[] splits = myRawData.randomSplit(new double[]{0.75, 0.25});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testData = splits[1];

        // Train KMeans
        KMeans kMeans = new KMeans().setK(3).setSeed(100);
        KMeansModel kMeansModel = kMeans.fit(trainingData);
        Dataset<Row> predictions = kMeansModel.transform(testData);
        predictions.show(false);

        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("label")
                .setPredictionCol("prediction")
                .setMetricName("accuracy");
        double accuracy = evaluator.evaluate(predictions);
        System.out.println("accuracy: " + accuracy);
    }
}

The console output is:

+-----+----------------------------+----------+
|label|features |prediction|
+-----+----------------------------+----------+
|2.0 |(3,[0,1,2],[7.0,7.0,7.0]) |2 |
|3.0 |(3,[0,1,2],[11.0,11.0,11.0])|2 |
|3.0 |(3,[0,1,2],[12.0,12.0,12.0])|1 |
|3.0 |(3,[0,1,2],[13.0,13.0,13.0])|1 |
+-----+----------------------------+----------+

Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Column prediction must be of type DoubleType but was actually IntegerType.
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.ml.util.SchemaUtils$.checkColumnType(SchemaUtils.scala:42)
at org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate(MulticlassClassificationEvaluator.scala:75)
at ClusteringTest.main(ClusteringTest.java:84)

Process finished with exit code 1

As you can see, the predictions come out as integers. But to use MulticlassClassificationEvaluator, I need those predictions converted to Double. How can I do that?
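For the mechanical cast itself, a minimal sketch (the helper class and method name are my own, not from the question; it assumes a `predictions` Dataset like the one produced by `kMeansModel.transform(testData)` above):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;

class PredictionCast {
    // Replace the IntegerType "prediction" column with a DoubleType copy,
    // so MulticlassClassificationEvaluator's schema check passes.
    static Dataset<Row> toDoublePrediction(Dataset<Row> predictions) {
        return predictions.withColumn(
                "prediction",
                predictions.col("prediction").cast(DataTypes.DoubleType));
    }
}
```

`withColumn` with an existing column name replaces that column, so the schema changes in place while all other columns are untouched.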

Best Answer

TL;DR: This is not the right approach.

KMeans is an unsupervised method, and the cluster identifiers it assigns are arbitrary (cluster IDs can be permuted) and have no relationship to the label column. Therefore, using MulticlassClassificationEvaluator to compare the existing labels against the KMeans output is meaningless.

You should use a supervised classifier instead, such as multinomial logistic regression or naive Bayes.
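A sketch of the supervised route, using logistic regression on toy data in the same shape as the question's (the class name, helper method, and data values are illustrative, not from the original post). Here the model's `prediction` column is already DoubleType and on the same scale as `label`, so the evaluator works as intended:

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class SupervisedExample {
    // Train a logistic regression on labeled toy data and return its accuracy.
    static double trainAndEvaluate(SparkSession session) {
        StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty())
        });
        List<Row> rows = Arrays.asList(
                RowFactory.create(0.0, Vectors.dense(1.0, 1.0, 1.0)),
                RowFactory.create(0.0, Vectors.dense(2.0, 2.0, 2.0)),
                RowFactory.create(1.0, Vectors.dense(7.0, 7.0, 7.0)),
                RowFactory.create(1.0, Vectors.dense(8.0, 8.0, 8.0)));
        Dataset<Row> data = session.createDataFrame(rows, schema);

        // A supervised learner predicts on the same scale as the label column,
        // and its "prediction" column is already DoubleType.
        LogisticRegressionModel model = new LogisticRegression().fit(data);
        Dataset<Row> predictions = model.transform(data);

        return new MulticlassClassificationEvaluator()
                .setLabelCol("label")
                .setPredictionCol("prediction")
                .setMetricName("accuracy")
                .evaluate(predictions);
    }

    public static void main(String[] args) {
        SparkSession session = SparkSession.builder()
                .appName("Supervised Example")
                .config("spark.master", "local")
                .getOrCreate();
        session.sparkContext().setLogLevel("ERROR");
        System.out.println("accuracy: " + trainAndEvaluate(session));
        session.stop();
    }
}
```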

If you want to stick with KMeans, use an appropriate quality metric, such as the one returned by computeCost, although this ignores the label information entirely.
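In Spark 2.1, KMeansModel.computeCost returns the within-cluster sum of squared errors (WSSSE; it was later deprecated in favor of ClusteringEvaluator). A sketch with hypothetical toy data, leaving the labels out entirely:

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class KMeansCostExample {
    // Fit KMeans on unlabeled toy points and return the clustering cost.
    static double wssse(SparkSession session) {
        StructType schema = new StructType(new StructField[]{
                new StructField("features", new VectorUDT(), false, Metadata.empty())
        });
        List<Row> rows = Arrays.asList(
                RowFactory.create(Vectors.dense(1.0, 1.0, 1.0)),
                RowFactory.create(Vectors.dense(2.0, 2.0, 2.0)),
                RowFactory.create(Vectors.dense(7.0, 7.0, 7.0)),
                RowFactory.create(Vectors.dense(8.0, 8.0, 8.0)));
        Dataset<Row> data = session.createDataFrame(rows, schema);

        KMeansModel model = new KMeans().setK(2).setSeed(100).fit(data);
        // Within-cluster sum of squared errors: lower means tighter clusters.
        return model.computeCost(data);
    }

    public static void main(String[] args) {
        SparkSession session = SparkSession.builder()
                .appName("KMeans Cost")
                .config("spark.master", "local")
                .getOrCreate();
        session.sparkContext().setLogLevel("ERROR");
        System.out.println("WSSSE: " + wssse(session));
        session.stop();
    }
}
```

Comparing WSSSE across different values of k (an "elbow" plot) is the usual way to judge a KMeans fit without labels.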

Regarding "java - Spark 2.1.0 - SparkML requirement failed", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/45392303/
