
Keras LSTM states

Reposted · Author: 行者123 · Updated: 2023-12-04 19:33:35

I want to run an LSTM in Keras and get both the outputs and the states. In TensorFlow there is something like this:

with tf.variable_scope("RNN"):
    for time_step in range(num_steps):
        if time_step > 0: tf.get_variable_scope().reuse_variables()
        (cell_output, state) = cell(inputs[:, time_step, :], state)
        outputs.append(cell_output)

Is there a way to do this in Keras, so that when a sequence is very long I can take the last state and feed it in with the next input? I know about stateful=True, but I also want to access the states during training. I know Keras uses scan rather than a for loop, but essentially I want to save the states after a run and then, on the next run, make them the initial states of the LSTM. In short: get both the outputs and the states.
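The state carry-over being asked for can be illustrated outside Keras entirely. Below is a minimal NumPy sketch (hypothetical weights and a plain Python loop instead of scan, not Keras' actual implementation) showing that saving the final (h, c) of one chunk and using it as the initial state of the next chunk reproduces a single pass over the full sequence, which is exactly what stateful=True does under the hood:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h, c, W, U, b):
    """One LSTM step. W: (input_dim, 4*units), U: (units, 4*units), b: (4*units,)."""
    units = h.shape[-1]
    z = x_t @ W + h @ U + b
    i = sigmoid(z[:, :units])                # input gate
    f = sigmoid(z[:, units:2 * units])       # forget gate
    g = np.tanh(z[:, 2 * units:3 * units])   # candidate cell state
    o = sigmoid(z[:, 3 * units:])            # output gate
    c_new = f * c + i * g
    h_new = o * np.tanh(c_new)
    return h_new, c_new

def run_chunk(chunk, h, c, W, U, b):
    """Run the cell over chunk (batch, time, input_dim) starting from states
    (h, c); return the per-step outputs and the final states."""
    outputs = []
    for t in range(chunk.shape[1]):
        h, c = lstm_step(chunk[:, t, :], h, c, W, U, b)
        outputs.append(h)
    return np.stack(outputs, axis=1), h, c

rng = np.random.default_rng(0)
batch, input_dim, units = 2, 3, 4
W = 0.1 * rng.normal(size=(input_dim, 4 * units))
U = 0.1 * rng.normal(size=(units, 4 * units))
b = np.zeros(4 * units)
seq = rng.normal(size=(batch, 10, input_dim))
h0 = c0 = np.zeros((batch, units))

# one pass over the whole sequence ...
full_out, h_full, c_full = run_chunk(seq, h0, c0, W, U, b)
# ... equals two passes where the saved states seed the second chunk
out1, h1, c1 = run_chunk(seq[:, :5, :], h0, c0, W, U, b)
out2, h2, c2 = run_chunk(seq[:, 5:, :], h1, c1, W, U, b)
assert np.allclose(h_full, h2) and np.allclose(c_full, c2)
```

Because the recurrence is deterministic given the states, splitting the sequence and threading (h, c) through changes nothing about the result; the question is only how to get those tensors out of Keras.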

Best Answer

Since LSTM is a layer, and a layer in Keras can only have one output (correct me if I'm wrong), you cannot get both tensors at once without modifying the source code.

Recently I have been hacking Keras to implement some advanced architectures, and some ideas you may not like do work. What I do is override the Keras layer so that we can access the tensor that represents the hidden states.

First, you can look at the call() function in keras/layers/recurrent.py to see how Keras does this:

def call(self, x, mask=None):
    # input shape: (nb_samples, time (padded with zeros), input_dim)
    # note that the .build() method of subclasses MUST define
    # self.input_spec with a complete input shape.
    input_shape = self.input_spec[0].shape
    if K._BACKEND == 'tensorflow':
        if not input_shape[1]:
            raise Exception('When using TensorFlow, you should define '
                            'explicitly the number of timesteps of '
                            'your sequences.\n'
                            'If your first layer is an Embedding, '
                            'make sure to pass it an "input_length" '
                            'argument. Otherwise, make sure '
                            'the first layer has '
                            'an "input_shape" or "batch_input_shape" '
                            'argument, including the time axis. '
                            'Found input shape at layer ' + self.name +
                            ': ' + str(input_shape))
    if self.stateful:
        initial_states = self.states
    else:
        initial_states = self.get_initial_states(x)
    constants = self.get_constants(x)
    preprocessed_input = self.preprocess_input(x)

    last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                         initial_states,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         constants=constants,
                                         unroll=self.unroll,
                                         input_length=input_shape[1])
    if self.stateful:
        self.updates = []
        for i in range(len(states)):
            self.updates.append((self.states[i], states[i]))

    if self.return_sequences:
        return outputs
    else:
        return last_output
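The key call above is K.rnn, which returns a triple: the output of the last timestep, the stacked outputs of every timestep, and the final states. A toy stand-in (not Keras' actual backend code; the running-sum step function below is invented purely for illustration) makes that contract concrete, and shows why the `states` list is sitting right there in `call()` just before the return:

```python
import numpy as np

def toy_rnn(step_fn, inputs, initial_states):
    """Minimal stand-in for K.rnn: scans step_fn over the time axis and
    returns (last_output, all_outputs, final_states), mirroring its API."""
    states = initial_states
    outputs = []
    for t in range(inputs.shape[1]):          # inputs: (batch, time, dim)
        output, states = step_fn(inputs[:, t, :], states)
        outputs.append(output)
    return outputs[-1], np.stack(outputs, axis=1), states

def step(x_t, states):
    """Toy step: the state is a running sum; the output is the new sum."""
    s = states[0] + x_t
    return s, [s]

x = np.ones((2, 4, 3))                        # batch=2, time=4, dim=3
last, seq, final = toy_rnn(step, x, [np.zeros((2, 3))])
assert np.allclose(last, 4.0)                 # sum of four ones per feature
assert seq.shape == (2, 4, 3)
assert np.allclose(final[0], seq[:, -1, :])   # final state == last output here
```

This is also why `return_sequences` is just a choice between returning `outputs` and `last_output`: both, plus `states`, are already computed in one scan.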

Second, we override the layer; here is a simple script:

import keras.backend as K
from keras.layers import Input, LSTM

class MyLSTM(LSTM):
    def call(self, x, mask=None):
        # .... blablabla, right before return

        # we add this line to get access to states
        self.extra_output = states

        if self.return_sequences:
            # .... blablabla, to the end

        # you should copy **exactly the same code** from keras.layers.recurrent

I = Input(shape=(...))
lstm = MyLSTM(20)
output = lstm(I)  # calling the layer invokes call() and creates lstm.extra_output
extra_output = lstm.extra_output  # refer to the target

# use K.function to calculate them **simultaneously**
calculate_function = K.function(inputs=[I], outputs=extra_output + [output])
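Worth noting for readers on newer versions: since Keras 2 this hack is no longer necessary, because the LSTM layer accepts return_state=True and then returns the output together with the final hidden and cell state tensors. A minimal sketch (the layer sizes here are arbitrary):

```python
import numpy as np
from keras.layers import Input, LSTM
from keras.models import Model

inp = Input(shape=(10, 8))                       # (timesteps, features)
out, state_h, state_c = LSTM(20, return_state=True)(inp)
model = Model(inputs=inp, outputs=[out, state_h, state_c])

o, h, c = model.predict(np.zeros((1, 10, 8)), verbose=0)
# h and c are the final hidden and cell states, each of shape (1, 20)
```

The symmetric `initial_state=[h, c]` argument on the layer call lets you feed saved states back in as the starting point of the next run, which covers the second half of the question as well.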

Regarding Keras LSTM states, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/38404085/
