gpt4 book ai didi

tensorflow - tf.nn.in_top_k : targets out of range

转载 作者:行者123 更新时间:2023-12-03 21:21:35 27 4
gpt4 key购买 nike

我从 tensorflow 中调整了 cifar10 网络,以解决我自己的分类问题。我已经训练了网络,现在我尝试使用 cifar10_eval.py 评估训练后的模型

top_k_op = tf.nn.in_top_k(logits, labels, 1)

但我收到以下错误。经过进一步调查,目标指数在2,3和4之间变化
tensorflow.python.framework.errors.InvalidArgumentError: targets[3] is out of range

到现在为止,我知道我的标签 Tensor 有问题。它是一个 int32-Tensor,其 shape(50,) 如下所示。
labels = {Tensor} Tensor("batch_processing/Reshape_1:0", shape=(50,), dtype=int32, device=/device:CPU:0)

我的数据集只有 2 个类/标签。也许这可能是问题所在。有谁知道,问题是什么?

最佳答案

总结一下,函数tf.nn.in_top_k(predictions, targets, k) (参见 doc )有参数:

  • 预测:形状[batch_size, num_classes] , 输入 float32
  • 目标(正确的标签):形状 [batch_size] , 输入 int32 或 int64


  • 该函数引发错误 InvalidArgumentError: targets[i] is out of range当元素 targets[i]超出范围 predictions[i] .

    例如,有 2 个类( num_classes=2 )和 targets=[1, 3] .
    使用这些目标,您将看到错误 InvalidArgumentError: targets[1] is out of range因为 targets[1] = 3超出 predictions[1] 的范围它只有形状 2。

    检查您的 labels是正确的,您可以打印其中的最大值:

    labels = ...
    labels_max = tf.reduce_max(labels)

    sess = tf.Session()
    print sess.run(labels_max)

    如果打印的值优于 num_classes , 你有个问题。

    关于tensorflow - tf.nn.in_top_k : targets out of range,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37587622/

    27 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com