gpt4 book ai didi

python - Keras:one-hot编码的类权重(class_weight)

转载 作者:太空狗 更新时间:2023-10-29 16:54:11 25 4
gpt4 key购买 nike

我想在 keras model.fit 中使用 class_weight 参数来处理不平衡的训练数据。通过查看一些文档,我了解到我们可以像这样传递一个字典:

class_weight = {0 : 1,
1: 1,
2: 5}

(在本例中,class-2 将在损失函数中得到更高的惩罚。)

问题是我的网络的输出具有单热编码,即 class-0 = (1, 0, 0),class-1 = (0, 1, 0),class-3 = (0, 0, 1).

我们如何使用 class_weight 进行单热编码输出?

通过查看 some codes in Keras ,看起来 _feed_output_names 包含输出类列表,但在我的例子中,model.output_names/model._feed_output_names 返回 [ 'dense_1']

相关:How to set class weights for imbalanced classes in Keras?

最佳答案

这是一个更短、更快的解决方案。如果你的 one-hot 编码的 y 是一个 np.array:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_integers = np.argmax(y, axis=1)
class_weights = compute_class_weight('balanced', np.unique(y_integers), y_integers)
d_class_weights = dict(enumerate(class_weights))

d_class_weights 然后可以传递给 .fit 中的 class_weight

关于python - Keras:one-hot编码的类权重(class_weight),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43481490/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com