gpt4 book ai didi

python - Tensorflow 2.0 - 这些模型预测代表概率吗?

转载 作者:行者123 更新时间:2023-12-01 00:16:17 25 4
gpt4 key购买 nike

我有一个非常简单的 Tensorflow 2 Keras 模型,可以对某些数据进行惩罚逻辑回归。我希望获得每个类别的概率,而不仅仅是 [0 或 1] 的预测值。

我想我得到了我想要的,但只是想确保这些数字是我认为的那样。我使用了 Tensorflow.keras 中的 model.predict_on_batch() 函数,但文档只是说这提供了一个 预测numpy 数组。不过,我相信我得到了概率,但我希望有人能够证实。

模型代码如下所示:

feature_layer = tf.keras.layers.DenseFeatures(features)                                                                    

model = tf.keras.Sequential([
feature_layer,
layers.Dense(1, activation='sigmoid', kernel_regularizer=tf.keras.regularizers.l1(0.01))
])
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])

predictions = model.predict_on_batch(validation_dataset)

print('Predictions for a single batch.')
print(predictions)

所以我得到的预测如下:

Predictions for a single batch.                                                                             
tf.Tensor(
[[0.10916319]
[0.14546806]
[0.13057315]
[0.11713684]
[0.16197902]
[0.19613355]
[0.1388464 ]
[0.14122346]
[0.26149303]
[0.12516734]
[0.1388464 ]
[0.14595506]
[0.14595506]]

现在,逻辑回归中的预测将是 0 或 1 的数组。但是因为我正在获取浮点值。然而,当实际上存在示例为 0 的概率和示例为 1 的概率时,我只是得到一个值。因此,我可以想象每行或示例都有一个包含 2 个概率的数组。当然,概率(Y = 0)+概率(Y = 1)= 1,所以这可能只是一些简洁的表示。

那么,下面数组中的值是否表示示例或 Y = 1 或其他值的概率?

最佳答案

此处表示的值:

tf.Tensor(                                                                                       
[[0.10916319]
[0.14546806]
[0.13057315]
[0.11713684]
[0.16197902]
[0.19613355]
[0.1388464 ]
[0.14122346]
[0.26149303]
[0.12516734]
[0.1388464 ]
[0.14595506]
[0.14595506]]
  1. 概率是与您的每一类相对应的。

  2. 由于您在最后一层使用了 sigmoid 激活,这些将 范围为 [0, 1]。

  3. 您的模型非常浅(层数很少),因此这些预测概率在类之间非常接近。我建议你添加更多层。

结论

为了回答您的问题,这些是概率,但仅取决于您的激活函数选择(sigmoid)。如果您使用 tanh 激活,这些值将在 [-1,1] 范围内。

请注意,由于使用了 binary_crossentropy 损失,因此每个类别的这些概率都是“二元”的 - 即 10.92% 的类别 1 存在,89.08% 的类别 1 不存在,对于其他类别,依此类推类。如果您希望预测遵循概率规则 (sum = 1),那么您应该考虑categorical_crossentropy

关于python - Tensorflow 2.0 - 这些模型预测代表概率吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59314746/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com