gpt4 book ai didi

apache-spark - 为什么 pyspark 的 BinaryClassificationEvaluator avgMetrics 返回一个大于 1 的值?

转载 作者:行者123 更新时间:2023-12-04 12:50:40 25 4
gpt4 key购买 nike

    evaluator = BinaryClassificationEvaluator()
grid = ParamGridBuilder().build() # no hyper parameter optimization
cv = CrossValidator(estimator=pipeline, estimatorParamMaps=grid, evaluator=evaluator)
cvModel = cv.fit(dataset)
evaluator.evaluate(cvModel.transform(dataset))

返回:

  • cvModel.avgMetrics = [1.602872634746238]
  • evaluator.evaluate(cvModel.transform(dataset)) = 0.7267754950388204

问题:

  1. 如果 avgMetric 是 ROC 下的面积,怎么会大于 1 (1.6)?
  2. 方案 evaluator.evaluate(cvModel.transform(dataset)) 实际上返回的是训练指标而不是交叉验证指标吗? (我们使用 dataset 进行拟合和评估)

最佳答案

这是一个错误 fixed最近。但是,它尚未发布。

根据您提供的内容,我使用以下代码重现了该问题:

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.linalg import Vectors
from pyspark.sql.types import Row

dataset = sc.parallelize([
Row(features=Vectors.dense([1., 0.]), label=1.),
Row(features=Vectors.dense([1., 1.]), label=0.),
Row(features=Vectors.dense([0., 0.]), label=1.),
]).toDF()

evaluator = BinaryClassificationEvaluator(metricName="areaUnderROC")
grid = ParamGridBuilder().addGrid('maxIter', [100, 10]).build() # no hyper parameter optimization
cv = CrossValidator(estimator=LogisticRegression(), estimatorParamMaps=grid, evaluator=evaluator)
cvModel = cv.fit(dataset)
evaluator.evaluate(cvModel.transform(dataset))

Out[23]: 1.0

cvModel.avgMetrics

Out[34]: [2.0, 2.0]

简单来说,

avgMetrics was summed, not averaged, across folds

编辑:

关于第二个问题,最简单的验证方法是提供测试数据集:

to_test = sc.parallelize([
Row(features=Vectors.dense([1., 0.]), label=1.),
Row(features=Vectors.dense([1., 1.]), label=0.),
Row(features=Vectors.dense([0., 1.]), label=1.),
]).toDF()

evaluator.evaluate(cvModel.transform(to_test))

Out[2]: 0.5

它确认函数调用返回测试数据集上的指标。

关于apache-spark - 为什么 pyspark 的 BinaryClassificationEvaluator avgMetrics 返回一个大于 1 的值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39331375/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com