gpt4 book ai didi

python - Keras train_on_batch() 与 fit() 相比不训练模型

转载 作者:行者123 更新时间:2023-12-01 08:16:22 30 4
gpt4 key购买 nike

我的数据集太大,RAM 无法容纳,因此我选择使用 train_on_batch 增量训练我的模型。为了测试这种方法是否有效,我使用了大数据的子集来运行一些初步测试。

但是,我在训练模型时遇到了一些问题,即使用 train_on_batch() 训练时模型的准确率停留在 10%。使用 fit(),我在 40 个 epoch 时获得了 95% 的准确率。我也尝试过 fit_generator() 并遇到了类似的问题。

使用 fit()

results = model.fit(x_train,y_train,batch_size=128,nb_epoch=40)

使用train_on_batch()

#386 has been chosen so that each batch size is 128
splitSize = len(y_train) // 386

for j in range(20):
print('epoch: '+str(j)+' ----------------------------')
np.random.shuffle(x_train)
np.random.shuffle(y_train)
xb = np.array_split(x_train,386)
yb = np.array_split(y_train,386)
sumAcc = 0
index = list(range(386))
random.shuffle(index)
for i in index:
results = model.train_on_batch(xb[i],yb[i])
sumAcc += results[1]
print(sumAcc/(386))

最佳答案

您使用的shuffle不正确,因为shuffle后的y_train与x_train不匹配。当您像这样进行洗牌时,每个数组都会以不同的顺序进行洗牌。您可以使用:

length = x_train.shape[0]
idxs = np.arange(0, length)
np.random.shuffle(idxs)

x_train = x_train[idxs]
y_train = y_train[idxs]

关于python - Keras train_on_batch() 与 fit() 相比不训练模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54963569/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com