
java - How to get confidence scores from Spark MLlib Logistic Regression in Java


Update: I tried the following approach to generate confidence scores, but it throws an exception. I am using this snippet:

double point = BLAS.dot(logisticregressionmodel.weights(), datavector);
double confScore = 1.0 / (1.0 + Math.exp(-point));

The exception I get is:

Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99)
at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala)

Can anyone help? The weights vector has more elements (198) than the data vector (I am generating 18 features), and dot() requires both vectors to be the same length. The underlying reason is that for a multinomial model, weights() returns the coefficients for all (numClasses - 1) non-base classes (plus intercepts, if the model was fit with one) flattened into a single vector, so it cannot be dotted directly against a single feature vector.
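To make the mismatch visible, here is a minimal diagnostic sketch (the variable name logisticregressionmodel is taken from the snippet above):

// Diagnostic sketch: in a multinomial MLlib model the weights vector is a
// flattened (numClasses - 1) x (numFeatures [+ 1 intercept]) matrix, which
// is why its size (198 here) differs from the feature vector's size (18).
int numClasses = logisticregressionmodel.numClasses();
int stride = logisticregressionmodel.weights().size() / (numClasses - 1);
System.out.println("numClasses = " + numClasses
    + ", per-class weight stride = " + stride);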

I am trying to implement a program in Java that trains on an existing dataset and then predicts on a new dataset, using the logistic regression algorithm available in Spark MLlib (1.5.0). My training and prediction programs are below, and I am using the multiclass implementation. The problem is that model.predict(vector) (note the lrmodel.predict() call in the prediction program) gives me only the predicted label. But what if I need a confidence score? How can I get it? I have gone through the API but could not find any method that returns a confidence score. Can anyone help me?
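One API note worth knowing, though it does not solve the multiclass case: for a binary model, MLlib's LogisticRegressionModel can return the raw probability instead of a 0/1 label once the classification threshold is cleared. A minimal sketch, reusing trainingData and vector from the code below:

// Binary logistic regression only: after clearThreshold(), predict()
// returns P(y = 1 | x) as a double in [0, 1] instead of a 0/1 label.
// With setNumClasses(18) as in this question, predict() still returns
// the class label, so this shortcut does not apply.
LogisticRegressionModel binaryModel =
    new LogisticRegressionWithLBFGS().run(trainingData.rdd());
binaryModel.clearThreshold();
double probability = binaryModel.predict(vector);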

Training program (generates the .model file):

public static void main(final String[] args) throws Exception {
    JavaSparkContext jsc = null;
    int salesIndex = 1;

    try {
        ...
        SparkConf sparkConf =
            new SparkConf().setAppName("Hackathon Train").setMaster(sparkMaster);
        jsc = new JavaSparkContext(sparkConf);
        ...

        JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache();

        // Filter out the header row.
        final String firstRdd = trainRDD.first().trim();
        JavaRDD<String> tempRddFilter =
            trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                private static final long serialVersionUID = 11111111111111111L;

                public Boolean call(final String arg0) {
                    return !arg0.trim().equalsIgnoreCase(firstRdd);
                }
            });

        ...
        // Convert each CSV row into a LibSVM-style "index:value" feature line.
        JavaRDD<String> featureRDD =
            tempRddFilter.map(new org.apache.spark.api.java.function.Function() {
                private static final long serialVersionUID = 6948900080648474074L;

                public Object call(final Object arg0) throws Exception {
                    ...
                    StringBuilder featureSet = new StringBuilder();
                    ...
                        featureSet.append(i - 2);
                        featureSet.append(COLON);
                        featureSet.append(strVal);
                        featureSet.append(SPACE);
                    }

                    return featureSet.toString().trim();
                }
            });

        // Persist the feature lines and reload them as LabeledPoints.
        List<String> featureList = featureRDD.collect();
        String featureOutput = StringUtils.join(featureList, NEW_LINE);
        String filePath = basePath + "lr.arff";
        FileUtils.writeStringToFile(new File(filePath), featureOutput, "UTF-8");

        JavaRDD<LabeledPoint> trainingData =
            MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache();

        // Train an 18-class logistic regression model with L-BFGS.
        final LogisticRegressionModel model =
            new LogisticRegressionWithLBFGS().setNumClasses(18).run(trainingData.rdd());

        // Write the model to lr.model via standard Java serialization.
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(baos);
        oos.writeObject(model);
        oos.flush();
        oos.close();
        FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"), baos.toByteArray());
        baos.close();

    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        if (jsc != null) {
            jsc.close();
        }
    }
}
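As an aside on persistence: a hedged alternative to the hand-rolled Java serialization above is MLlib's built-in save/load API (available since Spark 1.3), which writes the model to a directory and avoids serialVersionUID pitfalls. A sketch, reusing basePath from the code above:

// In the training program, instead of ObjectOutputStream:
model.save(jsc.sc(), basePath + "lrModel");

// In the prediction program, instead of ObjectInputStream:
LogisticRegressionModel lrmodel =
    LogisticRegressionModel.load(jsc.sc(), basePath + "lrModel");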

Prediction program (uses the lr.model generated by the training program):

public static void main(final String[] args) throws Exception {
    JavaSparkContext jsc = null;
    int salesIndex = 1;
    try {
        ...
        SparkConf sparkConf =
            new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster);
        jsc = new JavaSparkContext(sparkConf);

        // Deserialize the trained model written by the training program.
        ObjectInputStream objectInputStream =
            new ObjectInputStream(new FileInputStream(basePath + "lr.model"));
        LogisticRegressionModel lrmodel =
            (LogisticRegressionModel) objectInputStream.readObject();
        objectInputStream.close();

        ...

        JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache();

        // Filter out the header row.
        final String firstRdd = trainRDD.first().trim();
        JavaRDD<String> tempRddFilter =
            trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                private static final long serialVersionUID = 11111111111111111L;

                public Boolean call(final String arg0) {
                    return !arg0.trim().equalsIgnoreCase(firstRdd);
                }
            });

        ...
        // Ship the model to the executors once instead of serializing it per task.
        final Broadcast<LogisticRegressionModel> broadcastModel = jsc.broadcast(lrmodel);

        JavaRDD<String> featureRDD =
            tempRddFilter.map(new org.apache.spark.api.java.function.Function() {
                private static final long serialVersionUID = 6948900080648474074L;

                public Object call(final Object arg0) throws Exception {
                    ...
                    LogisticRegressionModel lrModel = broadcastModel.value();
                    String row = ((String) arg0);
                    String[] featureSetArray = row.split(CSV_SPLITTER);
                    ...
                    final Vector vector = Vectors.dense(doubleArr);
                    // predict() returns only the class label, not a confidence score.
                    double score = lrModel.predict(vector);
                    ...
                    return csvString;
                }
            });

        String outputContent =
            featureRDD.reduce(new org.apache.spark.api.java.function.Function2() {
                private static final long serialVersionUID = 1212970144641935082L;

                public Object call(Object arg0, Object arg1) throws Exception {
                    ...
                }
            });
        ...
        FileUtils.writeStringToFile(
            new File(basePath + "predicted-sales-data.csv"), sb.toString());
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        if (jsc != null) {
            jsc.close();
        }
    }
}

Best answer

After a lot of trial and error, I finally managed to write a custom function to generate confidence scores. It is not perfect at all, but it works for me for now!

private static double getConfidenceScore(
        final LogisticRegressionModel lrModel, final Vector vector) {
    /* Approach to get confidence scores starts */
    // For a multinomial model, weights() is a flattened
    // (numClasses - 1) x (numFeatures [+ 1 intercept]) matrix.
    Vector weights = lrModel.weights();
    int numClasses = lrModel.numClasses();
    int dataWithBiasSize = weights.size() / (numClasses - 1);
    boolean withBias = (vector.size() + 1) == dataWithBiasSize;

    // Compute the margin w_j . x (+ intercept) for each of the
    // numClasses - 1 non-base classes and keep the largest one.
    double maxMargin = 0.0;
    double margin = 0.0;
    for (int j = 0; j < (numClasses - 1); j++) {
        margin = 0.0;
        for (int k = 0; k < vector.size(); k++) {
            double value = vector.toArray()[k];
            if (value != 0.0) {
                margin += value * weights.toArray()[(j * dataWithBiasSize) + k];
            }
        }
        if (withBias) {
            margin += weights.toArray()[(j * dataWithBiasSize) + vector.size()];
        }
        if (margin > maxMargin) {
            maxMargin = margin;
        }
    }

    // Squash the winning margin through a sigmoid and report it as a
    // percentage rounded to two decimal places.
    double conf = 1.0 / (1.0 + Math.exp(-maxMargin));
    DecimalFormat twoDForm = new DecimalFormat("#.##");
    double confidenceScore = Double.valueOf(twoDForm.format(conf * 100));
    /* Approach to get confidence scores ends */
    return confidenceScore;
}
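For context, this helper mirrors the per-class margin loop inside MLlib's own multiclass prediction and then squashes the largest margin through a sigmoid, so it yields a heuristic score rather than a calibrated softmax probability. A hedged usage sketch, with variable names borrowed from the prediction program above:

// Inside the map function of the prediction program:
LogisticRegressionModel lrModel = broadcastModel.value();
Vector vector = Vectors.dense(doubleArr);
double predictedLabel = lrModel.predict(vector);
double confidenceScore = getConfidenceScore(lrModel, vector); // percentage, e.g. 87.5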

Regarding "java - How to get confidence scores from Spark MLlib Logistic Regression in Java", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/39328601/
