gpt4 book ai didi

python - 获取 tensorflow 中dynamic_rnn的最后输出?

转载 作者:IT老高 更新时间:2023-10-28 21:11:28 27 4
gpt4 key购买 nike

我正在使用 dynamic_rnn 处理 MNIST 数据:

# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
forget_bias=1.0,
initializer=tf.random_normal)

# Initial state
istate = lstm.zero_state(batch_size, "float")

# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)

# Output at last time point T
output_at_T = output[:, 27, :]

完整代码:http://pastebin.com/bhf9MgMe

lstm的输入是(batch_size, sequence_length, input_size)

因此 output_at_T 的维度为 (batch_size, sequence_length, num_units) 其中 num_units=200

我需要沿 sequence_length 维度获取最后一个输出。在上面的代码中,这被硬编码为 27。但是,我事先不知道 sequence_length,因为它可以在我的应用程序中逐批更改。

我试过了:

output_at_T = output[:, -1, :]

但它表示尚未实现负索引,我尝试使用占位符变量和常量(理想情况下,我可以将特定批处理的 sequence_length 输入其中);都没有用。

有什么方法可以在 tensorflow atm 中实现类似的东西?

最佳答案

您是否注意到 dynamic_rnn 有两个输出?

  1. 输出 1,我们称之为 h,在每个时间步都有所有输出(即 h_1、h_2 等),
  2. 输出 2,final_state,有两个元素:cell_state,以及批处理中每个元素的最后一个输出(只要您将序列长度输入到 dynamic_rnn)。

所以来自:

h, final_state= tf.dynamic_rnn( ..., sequence_length=[batch_size_vector], ... )

批处理中每个元素的最后状态是:

final_state.h

请注意,这包括批处理的每个元素的序列长度不同的情况,因为我们使用的是 sequence_length 参数。

关于python - 获取 tensorflow 中dynamic_rnn的最后输出?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36817596/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com