
python - TensorFlow: how to get the intermediate cell states (c) from an LSTMCell using dynamic_rnn?

Reposted · Author: 太空宇宙 · Updated: 2023-11-04 02:36:37

By default, the function dynamic_rnn outputs only the hidden state (known as m) at each time step, which can be obtained as follows:

cell = tf.contrib.rnn.LSTMCell(100)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs=inputs,
                                   sequence_length=sequence_lengths,
                                   dtype=tf.float32)

Is there a way to additionally obtain the intermediate (non-final) cell states (c)?

A TensorFlow contributor mentions that it can be done with a cell wrapper:

class Wrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, inner_cell):
        super(Wrapper, self).__init__()
        self._inner_cell = inner_cell

    @property
    def state_size(self):
        return self._inner_cell.state_size

    @property
    def output_size(self):
        return (self._inner_cell.state_size, self._inner_cell.output_size)

    def call(self, input, state):
        output, next_state = self._inner_cell(input, state)
        emit_output = (next_state, output)
        return emit_output, next_state

However, it doesn't seem to work. Any ideas?

Best answer

The suggested solution works for me, but the Layer.call method spec is more general, so the following Wrapper should be more robust to API changes. Try this:

class Wrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, inner_cell):
        super(Wrapper, self).__init__()
        self._inner_cell = inner_cell

    @property
    def state_size(self):
        return self._inner_cell.state_size

    @property
    def output_size(self):
        return (self._inner_cell.state_size, self._inner_cell.output_size)

    def call(self, input, *args, **kwargs):
        output, next_state = self._inner_cell(input, *args, **kwargs)
        emit_output = (next_state, output)
        return emit_output, next_state
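For intuition about what the wrapper records, here is a minimal NumPy sketch (not TensorFlow code; the weight layout and the lstm_step helper are illustrative assumptions) of an LSTM unrolled over time, collecting the intermediate cell state c at every step, which is exactly what the wrapper makes dynamic_rnn emit alongside the hidden state m:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, m_prev, c_prev, W, b):
    """One LSTM step: returns the new hidden state m and cell state c.
    W has shape (input_dim + num_units, 4 * num_units); b has shape (4 * num_units,)."""
    z = np.concatenate([x, m_prev]) @ W + b
    i, j, f, o = np.split(z, 4)                        # input, candidate, forget, output
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(j)  # intermediate cell state (c)
    m = sigmoid(o) * np.tanh(c)                        # hidden state (m)
    return m, c

rng = np.random.default_rng(0)
n_inputs, n_neurons = 3, 5
W = rng.normal(size=(n_inputs + n_neurons, 4 * n_neurons))
b = np.zeros(4 * n_neurons)

# Unroll over time, keeping c at every step -- the part that plain
# dynamic_rnn discards and the wrapper above preserves.
m = np.zeros(n_neurons)
c = np.zeros(n_neurons)
all_c = []
for x in rng.normal(size=(2, n_inputs)):  # 2 time steps
    m, c = lstm_step(x, m, c, W, b)
    all_c.append(c)

print(np.stack(all_c).shape)  # (2, 5): one cell state per time step
```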

Here is a test:

import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
print(outputs, states)

X_batch = np.array([
    # t = 0      t = 1
    [[0, 1, 2], [9, 8, 7]],  # instance 0
    [[3, 4, 5], [0, 0, 0]],  # instance 1
    [[6, 7, 8], [6, 5, 4]],  # instance 2
    [[9, 0, 1], [3, 2, 1]],  # instance 3
])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs_val = outputs[0].eval(feed_dict={X: X_batch})
    print(outputs_val)

The returned outputs is a tuple of (?, 2, 10) and (?, 2, 5) tensors, holding the LSTM states and outputs for every time step. Note that I'm using the "graduated" version of LSTMCell, from the tf.nn.rnn_cell package rather than tf.contrib.rnn. Also note state_is_tuple=False, which avoids having to deal with LSTMStateTuple.
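Because state_is_tuple=False concatenates c and m along the last axis, the evaluated (?, 2, 10) state tensor can be split back into per-step cell states and hidden states by slicing. A minimal NumPy sketch of that slicing (the array contents are made up; it assumes the c-then-m concatenation order used by LSTMCell):

```python
import numpy as np

n_neurons = 5
# Stand-in for the evaluated (batch, n_steps, 2 * n_neurons) state tensor:
states_val = np.arange(4 * 2 * 2 * n_neurons, dtype=np.float32).reshape(4, 2, 2 * n_neurons)

c_val = states_val[:, :, :n_neurons]  # intermediate cell states (c) at each step
m_val = states_val[:, :, n_neurons:]  # hidden states (m) at each step

print(c_val.shape, m_val.shape)  # (4, 2, 5) (4, 2, 5)
```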

Regarding "python - TensorFlow: how to get the intermediate cell states (c) from an LSTMCell using dynamic_rnn?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/47745027/
