gpt4 book ai didi

python - 使用 Keras 的不平衡数据集

转载 作者:太空狗 更新时间:2023-10-30 02:37:57 24 4
gpt4 key购买 nike

我正在使用 Python 和 Keras 库构建一个分类 ANN。我正在使用具有 3 个不同类的不平衡数据集训练 NN。第 1 类的流行率大约是第 2 类和第 3 类的 7.5 倍。作为补救措施,我接受了 this stackoverflow answer 的建议。并设置我的类(class)权重:

class_weight = {0 : 1,
1 : 6.5,
2: 7.5}

但是,问题来了:ANN 以相同的速率预测 3 个类别!

这没有用,因为数据集不平衡的,并且预测每个结果都有 33% 的机会是不准确的。

问题是:如何处理不平衡的数据集,以便 ANN 不会每次都预测第 1 类,而且 ANN 不会以相同的概率预测类别?

这是我正在使用的代码:

class_weight = {0 : 1,
1 : 6.5,
2: 7.5}

# Making the ANN
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout

classifier = Sequential()


# Adding the input layer and the first hidden layer with dropout
classifier.add(Dense(activation = 'relu',
input_dim = 5,
units = 3,
kernel_initializer = 'uniform'))
#Randomly drops 0.1, 10% of the neurons in the layer.
classifier.add(Dropout(rate= 0.1))

#Adding the second hidden layer
classifier.add(Dense(activation = 'relu',
units = 3,
kernel_initializer = 'uniform'))
#Randomly drops 0.1, 10% of the neurons in the layer.
classifier.add(Dropout(rate = 0.1))

# Adding the output layer
classifier.add(Dense(activation = 'sigmoid',
units = 2,
kernel_initializer = 'uniform'))

# Compiling the ANN
classifier.compile(optimizer = 'adam',
loss = 'binary_crossentropy',
metrics = ['accuracy'])

# Fitting the ANN to the training set
classifier.fit(X_train, y_train, batch_size = 100, epochs = 100, class_weight = class_weight)

最佳答案

我在您的模型中看到的最明显的问题是它的结构不适合分类。如果您的样本一次只能属于一个类,那么您不应该通过将 sigmoid 激活作为最后一层来忽略这一事实。

理想情况下,分类器的最后一层应该输出样本属于某个类别的概率,即(在您的情况下)数组 [a, b, c] 其中 a + b + c == 1.

如果您使用 sigmoid 输出,则输出 [1, 1, 1] 是可能的,尽管这不是您想要的。这也是您的模型未正确泛化的原因:假设您没有专门训练它更喜欢“不平衡”输出(如 [1, 0, 0]),它将默认为预测它在训练期间看到的平均值,并考虑重新加权。

尝试将最后一层的激活更改为 'softmax' 并将损失更改为 'catergorical_crossentropy':

# Adding the output layer
classifier.add(Dense(activation='softmax',
units=2,
kernel_initializer='uniform'))

# Compiling the ANN
classifier.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])

如果这不起作用,请查看我的其他评论并返回给我该信息,但我非常有信心这是主要问题。
干杯

关于python - 使用 Keras 的不平衡数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48547931/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com