gpt4 book ai didi

python - 多热标签编码

转载 作者:太空宇宙 更新时间:2023-11-03 21:14:31 24 4
gpt4 key购买 nike

我是 Tensorflow 新手。我有一个图像数据集,其中一张图像有多个标签。据我了解,我需要使用 tf.losses.sigmoid_cross_entropy() 。我尝试将 tf.one_hot 应用于标签,但是当我尝试将它们传递到损失函数时,出现错误,形状不兼容。我该如何解决这个问题?

最佳答案

您对tf.losses.sigmoid_cross_entropy的看法是正确的。您所需要做的就是用 tf.reduce_max 包装 tf.one_hot 来降低维度。

tf.reduce_max(tf.one_hot(labels, num_classes, dtype=tf.int32), axis=0)

这应该返回形状为 (num_classes,) 的张量,这正是损失函数所需的。

关于python - 多热标签编码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54816225/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com