gpt4 book ai didi

python - Keras - 恢复特定时间戳的 LSTM 隐藏状态

转载 作者:行者123 更新时间:2023-12-01 01:11:18 26 4
gpt4 key购买 nike

这个问题继续(LSTM - Making predictions on partial sequence)。正如上一个问题中所述,我已经训练了一个有状态 LSTM 模型,用于批量分类 100 个样本/标签,如下所示:

[Feature 1,Feature 2, .... ,Feature 3][Label 1]
[Feature 1,Feature 2, .... ,Feature 3][Label 2]
...
[Feature 1,Feature 2, .... ,Feature 3][Label 100]

型号代码:

def build_model(num_samples, num_features, is_training):
model = Sequential()
opt = optimizers.Adam(lr=0.0005, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0001)

batch_size = None if is_training else 1
stateful = False if is_training else True
first_lstm = LSTM(32, batch_input_shape=(batch_size, num_samples, num_features), return_sequences=True,
activation='tanh', stateful=stateful)

model.add(first_lstm)
model.add(LeakyReLU())
model.add(Dropout(0.2))
model.add(LSTM(16, return_sequences=True, activation='tanh', stateful=stateful))
model.add(Dropout(0.2))
model.add(LeakyReLU())
model.add(LSTM(8, return_sequences=True, activation='tanh', stateful=stateful))
model.add(LeakyReLU())
model.add(Dense(1, activation='sigmoid'))

if is_training:
model.compile(loss='binary_crossentropy', optimizer=opt,
metrics=['accuracy', f1])
return model

预测时,模型是无状态,批量大小为1,并且在每个样本之后检索分类概率,如下所示:

[Feature 1,Feature 2, .... ,Feature 10][Label 1] -> (model) -> probability

在模型处理完一批 100 个样本后调用 model.reset_states()。该模型有效并且结果非常好。

注意:我的数据是来自多个来源的事件。

<小时/>

我的问题:

当我测试模型时,我可以控制样本的顺序,并且可以确保样本来自同一来源。即所有前 100 个样本都来自源 1,然后在调用 model.reset_states() 后,接下来的 100 个样本来自源 2,依此类推。

但是,在我的生产环境中,示例以异步方式到达,例如:

前 3 个样本来自源 1,然后 2 个样本来自源 2,依此类推”

插图:

enter image description here

<小时/>

我的问题:

如何在每个源的特定时间戳处序列化模型状态,以便我可以在每个样本之后保存它,然后在来自同一源的新样本到达时将其加载回来。

最佳答案

您可以像这样获取和设置内部状态:

import keras.backend as K

def get_states(model):
return [K.get_value(s) for s,_ in model.state_updates]

def set_states(model, states):
for (d,_), s in zip(model.state_updates, states):
K.set_value(d, s)

关于python - Keras - 恢复特定时间戳的 LSTM 隐藏状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54850854/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com