gpt4 book ai didi

python - 在 tensorflow 0.12 中使用 MultiRNNCell

转载 作者:太空宇宙 更新时间:2023-11-04 02:59:12 24 4
gpt4 key购买 nike

在 Tensorflow 0.12 中,MultiRNNCell 的工作方式发生了变化,对于初学者,state_is_tuple 现在默认设置为 True ,此外,还有关于它的讨论:

state_is_tuple: If True, accepted and returned states are n-tuples, where n = len(cells). If False, the states are all concatenated along the column axis. This latter behavior will soon be deprecated.

我想知道如何将多层 RNN 与 GRU 单元一起使用,这是我目前的代码:

def _run_rnn(self, inputs):
# embedded inputs are passed in here
self.initial_state = tf.zeros([self._batch_size, self._hidden_size], tf.float32)
cell = tf.nn.rnn_cell.GRUCell(self._hidden_size)
cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=self._dropout_placeholder)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * self._num_layers, state_is_tuple=False)

outputs, last_state = tf.nn.dynamic_rnn(
cell = cell,
inputs = inputs,
sequence_length = self.sequence_length,
initial_state = self.initial_state
)

return outputs, last_state

我的输入查找单词 ID 并返回相应的嵌入向量。现在,使用上面的代码运行时出现以下错误:

ValueError:两个形状中的维度 1 必须相等,但对于输入形状为 [?]、[64,100]、[ 64,200]

我有 ? 的地方在我的占位符内:

def _add_placeholders(self):
self.input_placeholder = tf.placeholder(tf.int32, shape=[None, self._max_steps])
self.label_placeholder = tf.placeholder(tf.int32, shape=[None, self._max_steps])
self.sequence_length = tf.placeholder(tf.int32, shape=[None])
self._dropout_placeholder = tf.placeholder(tf.float32)

最佳答案

您的主要问题在于 initial_state 的设置。由于您的状态现在是一个元组,(更具体地说是 LSTMStateTuple,您不能直接将其分配给 tf.zeros。而是使用,

self.initial_state = cell.zero_state(self._batch_size, tf.float32)

看看 documentation了解更多。


要在代码中使用它,您需要在 feed_dict 中传递这个张量。做这样的事情,

state = sess.run(model.initial_state)
for batch in batches:
# Logic to add input placeholder in `feed_dict`
feed_dict[model.initial_state] = state
# Note I'm re-using `state` below
(loss, state) = sess.run([model.loss, model.final_state], feed_dict=feed_dict)

关于python - 在 tensorflow 0.12 中使用 MultiRNNCell,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41476519/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com