java - Apache Spark Machine Learning - can't get the Estimator example working

I'm struggling to take any of the example machine learning code from the Spark documentation and actually get it running as a Java program. Whether it's down to my limited knowledge of Java, Maven, or Spark (quite possibly all three), I haven't been able to find a useful explanation.

Take this example. To try to get it working, I used the following project structure:

.
├── pom.xml
└── src
└── main
└── java
└── SimpleEstimator.java

The Java file looks like this:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;


public class SimpleEstimator {
    public static void main(String[] args) {
        DataFrame training = sqlContext.createDataFrame(Arrays.asList(
            new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
            new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
            new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
            new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))
        ), LabeledPoint.class);

        LogisticRegression lr = new LogisticRegression();
        System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n");

        lr.setMaxIter(10)
          .setRegParam(0.01);

        LogisticRegressionModel model1 = lr.fit(training);

        System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());

        ParamMap paramMap = new ParamMap()
            .put(lr.maxIter().w(20)) // Specify 1 Param.
            .put(lr.maxIter(), 30) // This overwrites the original maxIter.
            .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.

        ParamMap paramMap2 = new ParamMap()
            .put(lr.probabilityCol().w("myProbability")); // Change output column name
        ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);

        LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
        System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());

        DataFrame test = sqlContext.createDataFrame(Arrays.asList(
            new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
            new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
            new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
        ), LabeledPoint.class);

        DataFrame results = model2.transform(test);
        for (Row r : results.select("features", "label", "myProbability", "prediction").collect()) {
            System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
                + ", prediction=" + r.get(3));
        }
    }
}

The pom file is as follows:

<project>
    <groupId>edu.berkeley</groupId>
    <artifactId>simple-estimator</artifactId>
    <modelVersion>4.0.0</modelVersion>
    <name>Simple Estimator</name>
    <packaging>jar</packaging>
    <version>1.0</version>
    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>1.5.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>1.5.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>1.5.0</version>
        </dependency>
    </dependencies>
</project>

If I run mvn package from the root of that directory, I get the following error:

[INFO] Scanning for projects...
[INFO]
[INFO] ------------------------------------------------------------------------
[INFO] Building Simple Estimator 1.0
[INFO] ------------------------------------------------------------------------
[INFO]
[INFO] --- maven-resources-plugin:2.6:resources (default-resources) @ simple-estimator ---
[WARNING] Using platform encoding (UTF-8 actually) to copy filtered resources, i.e. build is platform dependent!
[INFO] skip non existing resourceDirectory /Users/philip/study/spark/estimator/src/main/resources
[INFO]
[INFO] --- maven-compiler-plugin:3.1:compile (default-compile) @ simple-estimator ---
[INFO] Changes detected - recompiling the module!
[WARNING] File encoding has not been set, using platform encoding UTF-8, i.e. build is platform dependent!
[INFO] Compiling 1 source file to /Users/philip/study/spark/estimator/target/classes
[INFO] -------------------------------------------------------------
[ERROR] COMPILATION ERROR :
[INFO] -------------------------------------------------------------
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[15,26] cannot find symbol
symbol: variable sqlContext
location: class SimpleEstimator
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[44,22] cannot find symbol
symbol: variable sqlContext
location: class SimpleEstimator
[INFO] 2 errors
[INFO] -------------------------------------------------------------
[INFO] ------------------------------------------------------------------------
[INFO] BUILD FAILURE
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 1.567 s
[INFO] Finished at: 2015-09-16T16:54:20+01:00
[INFO] Final Memory: 36M/422M
[INFO] ------------------------------------------------------------------------
[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.1:compile (default-compile) on project simple-estimator: Compilation failure: Compilation failure:
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[15,26] cannot find symbol
[ERROR] symbol: variable sqlContext
[ERROR] location: class SimpleEstimator
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[44,22] cannot find symbol
[ERROR] symbol: variable sqlContext
[ERROR] location: class SimpleEstimator
[ERROR] -> [Help 1]
[ERROR]
[ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch.
[ERROR] Re-run Maven using the -X switch to enable full debug logging.
[ERROR]
[ERROR] For more information about the errors and possible solutions, please read the following articles:
[ERROR] [Help 1] http://cwiki.apache.org/confluence/display/MAVEN/MojoFailureException

Update

Thanks to @holden, I made sure to add these lines:

// additional imports
import org.apache.spark.api.java.*;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.SQLContext;

// added these as starting lines in class
SparkConf conf = new SparkConf().setAppName("Simple Estimator");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);

That got things a bit further, but now I get the following error:

[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.1:compile (default-compile) on project simple-estimator: Compilation failure
[ERROR] /Users/philip/study/spark/estimator/src/main/java/SimpleEstimator.java:[21,36] no suitable method found for createDataFrame(java.util.List<org.apache.spark.mllib.regression.LabeledPoint>,java.lang.Class<org.apache.spark.mllib.regression.LabeledPoint>)
[ERROR] method org.apache.spark.sql.SQLContext.<A>createDataFrame(org.apache.spark.rdd.RDD<A>,scala.reflect.api.TypeTags.TypeTag<A>) is not applicable
[ERROR] (cannot infer type-variable(s) A
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.rdd.RDD<A>))
[ERROR] method org.apache.spark.sql.SQLContext.<A>createDataFrame(scala.collection.Seq<A>,scala.reflect.api.TypeTags.TypeTag<A>) is not applicable
[ERROR] (cannot infer type-variable(s) A
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to scala.collection.Seq<A>))
[ERROR] method org.apache.spark.sql.SQLContext.createDataFrame(org.apache.spark.rdd.RDD<org.apache.spark.sql.Row>,org.apache.spark.sql.types.StructType) is not applicable
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.rdd.RDD<org.apache.spark.sql.Row>)
[ERROR] method org.apache.spark.sql.SQLContext.createDataFrame(org.apache.spark.api.java.JavaRDD<org.apache.spark.sql.Row>,org.apache.spark.sql.types.StructType) is not applicable
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.api.java.JavaRDD<org.apache.spark.sql.Row>)
[ERROR] method org.apache.spark.sql.SQLContext.createDataFrame(org.apache.spark.rdd.RDD<?>,java.lang.Class<?>) is not applicable
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.rdd.RDD<?>)
[ERROR] method org.apache.spark.sql.SQLContext.createDataFrame(org.apache.spark.api.java.JavaRDD<?>,java.lang.Class<?>) is not applicable
[ERROR] (argument mismatch; java.util.List<org.apache.spark.mllib.regression.LabeledPoint> cannot be converted to org.apache.spark.api.java.JavaRDD<?>)

The code the error refers to comes straight from the example:

DataFrame training = sqlContext.createDataFrame(Arrays.asList(
    new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
    new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
    new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
    new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))
), LabeledPoint.class);

Best Answer

The examples generally don't create the sqlContext or sc (the SparkContext), since those are the same for every example. http://spark.apache.org/docs/latest/sql-programming-guide.html shows how to create a sqlContext, and http://spark.apache.org/docs/latest/quick-start.html shows how to create an sc (SparkContext).

You probably want something like this:

Some more imports:

//Additional imports
import org.apache.spark.api.java.*;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.SQLContext;

And add at the start of your main method:

// In your method:
SparkConf conf = new SparkConf().setAppName("Simple Application");
JavaSparkContext sc = new JavaSparkContext(conf);
SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
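One caveat, as an aside (an assumption on my part, not something the original example covers): if you launch the program directly, e.g. from your IDE, rather than through spark-submit, the SparkConf also needs a master URL, for instance:

SparkConf conf = new SparkConf()
    .setAppName("Simple Application")
    .setMaster("local[*]"); // assumption: local test run; spark-submit would normally supply the master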

Based on your update, the second problem you're running into is creating the DataFrame (again something the Java examples leave out). The method you are trying to use hasn't been implemented (in fact I have a pending pull request to implement something similar at https://github.com/apache/spark/pull/8779 , although that version requires Rows and schemas, and I've filed a JIRA, https://issues.apache.org/jira/browse/SPARK-10720 , to track adding this for the local JavaBean case).

Thankfully, working around it is just one extra step. Instead of:

DataFrame test = sqlContext.createDataFrame(Arrays.asList(
    new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
    new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
    new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
), LabeledPoint.class);

change it to:

DataFrame test = sqlContext.createDataFrame(sc.parallelize(
    Arrays.asList(
        new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
        new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
        new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))
    )), LabeledPoint.class);
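The same change presumably needs to be made to the training DataFrame at the top of main as well, since that is the call the [21,36] compiler error points at. A sketch of that, assuming the sc created in your update is in scope:

DataFrame training = sqlContext.createDataFrame(sc.parallelize(
    Arrays.asList(
        new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
        new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
        new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
        new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))
    )), LabeledPoint.class); // sc.parallelize turns the List into a JavaRDD, which createDataFrame(JavaRDD, Class) accepts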

On java - Apache Spark Machine Learning - can't get the Estimator example working, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/32613413/
