gpt4 book ai didi

python - 将先验信念添加到神经网络中

转载 作者:行者123 更新时间:2023-11-30 08:56:37 27 4
gpt4 key购买 nike

我正忙于解决一个分类问题,分为三个类别。其中一类从未被预测/分类。我想知道是否有办法将先验信念注入(inject)我的神经网络中,无论是设计与否。

我的足球预测模型预测[平局、主队获胜、客队获胜]。我的类(class)非常平衡(40%、30%、30%)。占数据 40% 的类 [Draw] 是我的神经网络从未预测到的。我的数据集包含 1900 个样本。

我使用的是具有 2 到 4 个隐藏层的深度神经网络。

我的最佳模型的代码(基于训练/验证损失)如下:

X_all = df.copy()

train_cols = ['a_line0','a_line1','a_line2','a_line3','a_line4','a_line5',
'a_line6','a_line7','a_line8','a_line9','a_line10','h_line0',
'h_line1','h_line2','h_line3','h_line4','h_line5','h_line6',
'h_line7','h_line8','h_line9','h_line10','odds0','odds1','odds2']


x = X_all[train_cols]

x_v = x.values #returns a numpy array
min_max_scaler = preprocessing.MinMaxScaler()
x_scaled = min_max_scaler.fit_transform(x_v)
x = pd.DataFrame(x_scaled)

y = X_all['result']
ohe = OneHotEncoder(n_values=3,categories='auto')
y = ohe.fit_transform(y.reshape(-1,1))

X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=0)

for lr,ep in [(0.001,300)]:
model = Sequential()
model.add(Dense(25, input_dim=25, activation='relu'))
model.add(Dense(36, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(12, activation='relu'))
model.add(Dense(3, activation='sigmoid'))
adam = kr.optimizers.Adam(lr=lr, decay=1e-6)
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
model.fit(X_train, y_train, epochs=ep, batch_size=10,verbose = 0)
_, accuracy = model.evaluate(X_test, y_test)
_, accuracy1 = model.evaluate(X_train, y_train)
print('Testing Accuracy: %.2f' % (accuracy*100),'Train Accuracy: %.2f' % (accuracy1*100), 'learning rate : ', lr)

如果代码有点困惑,我深表歉意。我的模型在我的网络配置上也过度拟合了 +- 16%(52% vs 68%)。

最佳答案

由于您处于多类单标签设置(即您的标签是互斥的),因此您不应在最后一层中使用sigmoid作为激活;将其更改为

model.add(Dense(3, activation='softmax'))

此外,默认情况下不应使用 dropout;首先删除它,只有在改善结果时才添加它。

关于python - 将先验信念添加到神经网络中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57871767/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com