gpt4 book ai didi

python - Tensorflow 加权交叉熵损失函数在 DNN 分类器估计器函数中的位置在哪里?

转载 作者:行者123 更新时间:2023-11-30 09:05:52 25 4
gpt4 key购买 nike

我目前正在使用 tf.estimator.DNNClassifier 开发具有高度倾斜数据(90% 负数/10% 正数)的二项式分类算法。由于我训练的所有模型都会将所有样本标记为负样本,因此我需要实现加权损失函数。

我在这里查看了许多不同的问题,其中许多问题都很有启发性。然而,我无法得到关于如何实际实现这些功能的实用的端到端答案。 Thisthis线程是最好的。

我的问题是:我想使用 tf.nn.weighted_cross_entropy_with_logits(),但我不知道应该将其插入代码中的何处。

我有一个构建特征列的函数:

def construct_feature_columns(input_features):
return set([tf.feature_column.numeric_column(my_feature)
for my_feature in input_features])

定义 tf.estimator.DNNClassifier 以及其他参数(例如优化器和输入函数)的函数:

def train_nn_classifier_model(
learning_rate,
steps,
batch_size,
hidden_units,
training_examples,
training_targets,
validation_examples,
validation_targets):

dnn_classifier = tf.estimator.DNNClassifier(
feature_columns=construct_feature_columns(training_examples),
hidden_units=hidden_units,
optimizer=my_optimizer)

训练函数:

dnn_classifier.train(input_fn=training_input_fn, steps=steps_per_period)

预测函数,用于计算训练时的误差:

training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)

优化器:

  my_optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(my_optimizer, 5.0)

输入函数(用于训练输入、预测训练输入和验证输入):

  training_input_fn = lambda: my_input_fn(
training_examples,
training_targets['True/False'],
batch_size=batch_size)

我应该在哪里插入tf.nn.weighted_cross_entropy_with_logits ,所以我的模型使用这个函数计算损失?

此外,如何在交叉熵函数内调用目标(与 logits 类型和形状相同的张量)?它是training_targets DataFrame吗?它是以training_targets作为输入的输入函数的输出吗?

logits 具体是多少?因为对我来说,它们应该是来自函数的预测:

training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)

但这对我来说没有意义。我尝试了很多不同的方法来实现它,但没有一个起作用。

最佳答案

我讨厌传递坏消息,但是 DNN Classifier不支持自定义损失函数:

Loss is calculated by using softmax cross entropy.

这是文档中唯一提到损失(函数)的地方,我找不到任何讨论通过直接更改 DNNClassifier 来解决此问题的帖子。相反,看起来您必须构建自己的 custom Estimator .

关于python - Tensorflow 加权交叉熵损失函数在 DNN 分类器估计器函数中的位置在哪里?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52156752/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com