gpt4 book ai didi

python - 文本预测 LSTM 神经网络的问题

转载 作者:行者123 更新时间:2023-11-30 09:40:51 24 4
gpt4 key购买 nike

我正在尝试使用递归神经网络(LSTM)和书籍数据集进行文本预测。无论我如何尝试更改层大小或其他参数,它总是会过度拟合。

我一直在尝试更改层数、LSTM 层中的单元数、正则化、归一化、batch_size、洗牌训练数据/验证数据、将数据集更改为更大。现在我尝试使用 ~140kb txt 书。我也尝试过 200kb、1mb、5mb。

创建训练/验证数据:

sequence_length = 30

x_data = []
y_data = []

for i in range(0, len(text) - sequence_length, 1):
x_sequence = text[i:i + sequence_length]
y_label = text[i + sequence_length]

x_data.append([char2idx[char] for char in x_sequence])
y_data.append(char2idx[y_label])

X = np.reshape(x_data, (data_length, sequence_length, 1))
X = X/float(vocab_length)
y = np_utils.to_categorical(y_data)

# Split into training and testing set, shuffle data
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2, shuffle=False)

# Shuffle testing set
X_test, y_test = shuffle(X_test, y_test, random_state=0)

创建模型:

model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(LSTM(256, return_sequences=True, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(LSTM(256, recurrent_initializer='glorot_uniform', recurrent_dropout=0.3))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))

enter image description here编译模型:

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

我得到以下特征: enter image description here

我不知道该怎么处理这种过度拟合,因为我正在互联网上搜索,尝试了很多方法,但似乎都不起作用。

如何才能获得更好的结果?这些预测目前看来不太好。

最佳答案

以下是我接下来要尝试的一些事情。(我也是业余爱好者,如有错误请指正)

  1. 尝试提取vector representation从文字来看。尝试 word2vec、GloVe、FastText、ELMo。提取向量表示,然后将它们输入网络。您还可以创建 embedding layer来帮助解决这个问题。这个blog有更多信息。
  2. 256 个循环单元可能太多了。我认为永远不应该从一个庞大的网络开始。从小处开始。看看你是否适配不足。如果是,那就更大。
  3. 关闭优化器。我发现 Adam 倾向于过度拟合。我在 rmsprop 和 Adadelta 方面取得了更好的成功。
  4. 也许,attention is all you need? Transformers 最近为 NLP 做出了巨大贡献。也许你可以尝试implementing simple soft attention mechanism在你的网络中。这是nice video series如果您还不熟悉。安interactive research paper就在上面。
  5. CNN 也是 pretty dope在 NLP 应用中。尽管直观上它们对于文本数据没有任何意义(对大多数人来说)。也许你可以尝试利用它们、堆叠它们等。尝试一下。这是guide关于如何使用它进行句子分类。我知道,您的域不同。但我认为直觉会延续下去。 :)

关于python - 文本预测 LSTM 神经网络的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58764687/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com