gpt4 book ai didi

python - 增加一个类别

转载 作者:太空狗 更新时间:2023-10-30 01:04:56 25 4
gpt4 key购买 nike

我构建了一个 TensorFlow 模型,该模型使用 DNNClassifier 将输入分为两类。

我的问题是结果 1 在 90-95% 的时间里都会发生。因此,TensorFlow 为我的所有预测提供了相同的概率。

我正在尝试预测其他结果(例如,结果 2 的误报比错过结果 2 的可能发生更可取)。我知道在一般的机器学习中,在这种情况下,尝试提高结果 2 的权重是值得的。

但是,我不知道如何在 TensorFlow 中执行此操作。 documentation暗示它是可能的,但我找不到它实际是什么样子的任何例子。有没有人成功做到这一点,或者有谁知道我在哪里可以找到一些示例代码或详尽的解释(我使用的是 Python)?

注意:当有人使用 TensorFlow 的更基本部分而不是估算器时,我已经看到暴露的权重被操纵。出于维护原因,我需要使用估算器来执行此操作。

最佳答案

tf.estimator.DNNClassifier构造函数有 weight_column 参数:

weight_column: A string or a _NumericColumn created by tf.feature_column.numeric_column defining feature column representing weights. It is used to down weight or boost examples during training. It will be multiplied by the loss of the example. If it is a string, it is used as a key to fetch weight tensor from the features. If it is a _NumericColumn, raw tensor is fetched by key weight_column.key, then weight_column.normalizer_fn is applied on it to get weight tensor.

所以只需添加一个新列并为稀有类填充一些权重:

weight = tf.feature_column.numeric_column('weight')
...
tf.estimator.DNNClassifier(..., weight_column=weight)

[更新]这是一个完整的工作示例:

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('mnist', one_hot=False)
train_x, train_y = mnist.train.next_batch(1024)
test_x, test_y = mnist.test.images, mnist.test.labels

x_column = tf.feature_column.numeric_column('x', shape=[784])
weight_column = tf.feature_column.numeric_column('weight')
classifier = tf.estimator.DNNClassifier(feature_columns=[x_column],
hidden_units=[100, 100],
weight_column=weight_column,
n_classes=10)

# Training
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': train_x, 'weight': np.ones(train_x.shape[0])},
y=train_y.astype(np.int32),
num_epochs=None, shuffle=True)
classifier.train(input_fn=train_input_fn, steps=1000)

# Testing
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={'x': test_x, 'weight': np.ones(test_x.shape[0])},
y=test_y.astype(np.int32),
num_epochs=1, shuffle=False)
acc = classifier.evaluate(input_fn=test_input_fn)
print('Test Accuracy: %.3f' % acc['accuracy'])

关于python - 增加一个类别,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48098951/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com