gpt4 book ai didi

python - 生成预测列表

转载 作者:太空宇宙 更新时间:2023-11-03 11:39:15 25 4
gpt4 key购买 nike

使用 Tensorflow,是否有输出网络预测的方法?

我的输出一直在为 12 个类使用 One Hot Representation

[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
[0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]
etc...

当尝试从我的模型中获取给定输入的预测时,我运行了以下代码

    prediction=tf.argmax(y,1)
best = sess.run([prediction],feed_dict={x: batch_x, y: batch_y,
seqlen: batch_seqlen})
print("Prediction: ")
print(best)

当我运行这段代码并打印预测时,我的输出是:

[array([1, 5, 7, 7, 7, 4, 7, 9, 4, 4, 9, 6, 7, 8, 3, 2], dtype=int64)]

我输入的批量大小是 16,所以有 16 个输出确实有意义。然而,这些都不是 OneHot 表示(不确定 tensorflow 的输出是否打算作为索引进行交互,所以 1 实际上是某种形式的 onehot

有没有一种方法可以让每个特定的 X 创建一个预测排名列表,模型在给定 X 的情况下最有可能找到什么?

这有意义吗?

最佳答案

您正在获取 1-hot 向量的 tf.argmax,所以这就是您看到索引而不是您预期的 1-hot 向量的原因。

要获得分类预测的排序列表,您可以获取预测向量并应用 values, indices = tf.nn.top_k(prediction) values 将是您的预测按降序排列,indices 将是那些排序的 values 索引。

关于python - 生成预测列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53180481/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com