gpt4 book ai didi

python - Tensorflow 估计器 : Switching to careful_interpolation to get the correct PR-AUC of a model

转载 作者:太空宇宙 更新时间:2023-11-03 13:08:20 25 4
gpt4 key购买 nike

在我的项目中,我使用预制的估计器 DNNClassifier。这是我的估算器:

model = tf.estimator.DNNClassifier(
hidden_units=network,
feature_columns=feature_cols,
n_classes= 2,
activation_fn=tf.nn.relu,
optimizer=tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001
),
config=chk_point_run_config,
model_dir=MODEL_CHECKPOINT_DIR
)

当我使用 eval_res = model.evaluate(..) 评估模型时,我收到以下警告:

WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to "careful_interpolation" instead.

如何切换到 careful_interpolation 以从 evaluate() 方法获得正确的结果?

Tensorflow 版本:1.8

最佳答案

不幸的是,使用预制估算器几乎没有自由定制评估过程。目前,DNNClassifier 似乎没有提供调整评估指标的方法,对于其他估计器也是如此。

尽管不理想,但一种解决方案是使用 tf.contrib.metrics.add_metrics 使用所需指标来增强估算器,如果将完全相同的 key 分配给新指标,它将替换旧指标:

If there is a name conflict between this and estimators existing metrics, this will override the existing one.

它具有适用于任何生成概率预测的估算器的优势,但代价是仍然为每个评估计算覆盖的指标。 DNNClassifier 估计器在键 'logistic' 下提供逻辑值(介于 0 和 1 之间)(固定估计器中可能的键列表是 here )。对于其他估计器头来说,情况可能并非总是如此,但可能有替代方案:在使用 tf.contrib.estimator.multi_label_head 构建的多标签分类器中, logistic 不可用,但可以使用 probabilities

因此,代码将如下所示:

def metric_auc(labels, predictions):
return {
'auc_precision_recall': tf.metrics.auc(
labels=labels, predictions=predictions['logistic'], num_thresholds=200,
curve='PR', summation_method='careful_interpolation')
}

estimator = tf.estimator.DNNClassifier(...)
estimator = tf.contrib.estimator.add_metrics(estimator, metric_auc)

评估时,警告信息仍然会出现,但经过仔细插值的AUC会在不久后被调用。将此指标分配给不同的键还可以让您检查两种求和方法之间的差异。我对多标签逻辑回归任务的测试表明,测量结果可能确实略有不同:auc_precision_recall = 0.05173396,auc_precision_recall_careful = 0.05059402。


尽管 documentation,默认求和方法仍然是 'trapezoidal' 也是有原因的表明仔细插值是“严格首选”。作为commented in pull request #19079 , 更改将显着向后不兼容。对同一拉取请求的后续评论提出了上述解决方法。

关于python - Tensorflow 估计器 : Switching to careful_interpolation to get the correct PR-AUC of a model,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50850258/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com