gpt4 book ai didi

python - 多标签分类: How to learn threshold values?

转载 作者:行者123 更新时间:2023-11-30 22:36:14 26 4
gpt4 key购买 nike

我有一个深度 CNN,可以很好地进行多类分类。我想“升级”挑战并针对多标签分类问题对其进行训练。

为此,我用 sigmoid 替换了 softmax,并尝试训练我的网络以最小化:

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred)

但我最终得到了奇怪的预测:

Prediction for Im1 : [ 0.59275776  0.08751075  0.37567005  0.1636796   0.42361438  0.08701646

0.38991812 0.54468459 0.34593087 0.82790571]

Prediction for Im1 : [ 0.52609032  0.07885984  0.45780018  0.04995904  0.32828355  0.07349177

0.35400775 0.36479294 0.30002621 0.84438241]

Prediction for Im1 : [ 0.58714485  0.03258472  0.3349618   0.03199361  0.54665488  0.02271551

0.43719986 0.54638696 0.20344526 0.88144571]

因此,我想尝试为每个类别设置网络学习阈值,以确定样本是否属于该类别。

所以我将其添加到我的代码中:

initial = tf.truncated_normal([numberOfClasses], stddev=0.1)
W_thresh = tf.Variable(initial)

y_predict_thresh = int(y_predict > W_thresh)

但是我有一个错误:

TypeError: int() argument must be a string or a number, not 'Tensor'.

任何人有任何想法可以帮助我前进(如何避免这个错误?,我的数据集真的不平衡这一事实会导致这些“恒定”预测吗?关于多标签分类的其他建议?,...) ?

谢谢

编辑:我刚刚意识到,对于反向传播来说,进行阈值处理可能并不是很酷:/

最佳答案

不知道你是否还需要它,但是你可以使用tensorflow转换函数tf.to_int32、tf.to_int64。在评估之前,您的表达式是 python 的对象,因此它不能简单地将其转换为 int()

这可以满足您的需要:

with tf.Session() as sess:
check = sess.run([tf.to_int64(W1 > W2)])

关于python - 多标签分类: How to learn threshold values?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44265934/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com