gpt4 book ai didi

pyspark - 如何在PySpark中从spark.ml提取模型超参数?

转载 作者:行者123 更新时间:2023-12-03 11:36:51 25 4
gpt4 key购买 nike

我正在修改PySpark文档中的一些交叉验证代码,并试图让PySpark告诉我选择了哪种模型:

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.linalg import Vectors
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

dataset = sqlContext.createDataFrame(
[(Vectors.dense([0.0]), 0.0),
(Vectors.dense([0.4]), 1.0),
(Vectors.dense([0.5]), 0.0),
(Vectors.dense([0.6]), 1.0),
(Vectors.dense([1.0]), 1.0)] * 10,
["features", "label"])
lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.regParam, [0.1, 0.01, 0.001, 0.0001]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
cvModel = cv.fit(dataset)

在PySpark shell中运行此代码,可以得到线性回归模型的系数,但似乎找不到交叉验证过程选择的 lr.regParam的值。有任何想法吗?
In [3]: cvModel.bestModel.coefficients
Out[3]: DenseVector([3.1573])

In [4]: cvModel.bestModel.explainParams()
Out[4]: ''

In [5]: cvModel.bestModel.extractParamMap()
Out[5]: {}

In [15]: cvModel.params
Out[15]: []

In [36]: cvModel.bestModel.params
Out[36]: []

最佳答案

也遇到这个问题。我发现由于某些原因(我不知道为什么)需要调用java属性。因此,只需执行以下操作:

from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder, CrossValidator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(metricName="mae")
lr = LinearRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [500]) \
.addGrid(lr.regParam, [0]) \
.addGrid(lr.elasticNetParam, [1]) \
.build()
lr_cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, \
evaluator=evaluator, numFolds=3)
lrModel = lr_cv.fit(your_training_set_here)
bestModel = lrModel.bestModel

打印出所需的参数:
>>> print 'Best Param (regParam): ', bestModel._java_obj.getRegParam()
0
>>> print 'Best Param (MaxIter): ', bestModel._java_obj.getMaxIter()
500
>>> print 'Best Param (elasticNetParam): ', bestModel._java_obj.getElasticNetParam()
1

这也适用于其他方法,例如 extractParamMap()。他们应该尽快解决此问题。

关于pyspark - 如何在PySpark中从spark.ml提取模型超参数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36697304/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com