gpt4 book ai didi

python - 当 state_is_tuple=True 时如何设置 TensorFlow RNN 状态?

转载 作者:太空狗 更新时间:2023-10-29 17:04:06 26 4
gpt4 key购买 nike

我写了一个RNN language model using TensorFlow .该模型作为 RNN 类实现。图结构在构造函数中构建,而 RNN.trainRNN.test 方法运行它。

当我移动到训练集中的新文档时,或者当我想在训练期间运行验证集时,我希望能够重置 RNN 状态。我通过管理训练循环内的状态,通过提要字典将其传递到图中来做到这一点。

在构造函数中,我这样定义 RNN

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
initial_state=self.state)

训练循环看起来像这样

 for document in document:
state = session.run(self.reset_state)
for x, y in document:
_, state = session.run([self.train_step, self.next_state],
feed_dict={self.x:x, self.y:y, self.state:state})

xy 是文档中的批量训练数据。这个想法是我在每批之后传递最新状态,除非我开始一个新文档,当我通过运行 self.reset_state 将状态清零时。

这一切都有效。现在我想更改我的 RNN 以使用推荐的 state_is_tuple=True。但是,我不知道如何通过提要字典传递更复杂的 LSTM 状态对象。我也不知道要将什么参数传递给构造函数中的 self.state = tf.placeholder(...) 行。

这里正确的策略是什么? dynamic_rnn 的示例代码或文档仍然不多。


TensorFlow 问题 26952838显得相关。

A blog post on WILDML 解决了这些问题,但没有直接给出答案。

另见 TensorFlow: Remember LSTM state for next batch (stateful LSTM) .

最佳答案

Tensorflow 占位符的一个问题是您只能使用 Python 列表或 Numpy 数组(我认为)来提供它。所以你不能在 LSTMStateTuple 的元组中保存运行之间的状态。

我通过将状态保存在这样的张量中解决了这个问题

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

LSTM 层中有两个组件,单元状态隐藏状态,这就是“2”的来源。 (这篇文章很棒:https://arxiv.org/pdf/1506.00019.pdf)

在构建图表时,您会像这样解压并创建元组状态:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
[tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
for idx in range(num_layers)]
)

然后你以通常的方式获得新状态

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

不应该是这样的……也许他们正在研究解决方案。

关于python - 当 state_is_tuple=True 时如何设置 TensorFlow RNN 状态?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39112622/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com