gpt4 book ai didi

python - tf.nn.dynamic_rnn() "outputs"与 "state"的概念理解

转载 作者:太空宇宙 更新时间:2023-11-04 01:54:11 26 4
gpt4 key购买 nike

上下文

我正在阅读 Hands on ML 的第二部分并且正在寻找关于何时使用“输出”以及何时在 RNN 的损失计算中使用“状态”的一些清晰度。

在书中(对于那些拥有该书的人,第 396 页),作者说,“请注意,全连接层连接到 states 张量,它仅包含 RNN 的最终状态,”引用到展开超过 28 个步骤的序列分类器。自 states变量将有 len(states) == <number_of_hidden_layers> ,在构建深度 RNN 时,我一直使用 states[-1] 来仅连接到最后一层的最终状态。例如:

# hidden_layer_architecture = list of ints defining n_neurons in each layer
# example: hidden_layer_architecture = [100 for _ in range(5)]
layers = []
for layer_id, n_neurons in enumerate(hidden_layer_architecture):

hidden_layer = tf.contrib.rnn.BasicRNNCell(n_neurons,
activation=tf.nn.tanh,
name=f'hidden_layer_{layer_id}')

layers.append(hidden_layer)

recurrent_hidden_layers = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(recurrent_hidden_layers,
X_, dtype=tf.float32)

logits = tf.layers.dense(states[-1], n_outputs, name='outputs')

鉴于作者之前的陈述,这按预期工作。但是,我不明白什么时候会使用 outputs变量(tf.nn.dynamic_rnn() 的第一个输出)

我看过this question ,它很好地回答了细节问题,并提到,“如果您只对单元格的最后一个输出感兴趣,您可以只对时间维度进行切片以仅选择最后一个元素(例如 outputs[:, -1, :] )。 “我推断这意味着类似于 states[-1] == outputs[:, -1, :] 的意思,这在测试时是错误的。为什么不是这样呢?如果输出是单元格在每个时间步长的输出,为什么不是这样呢?一般来说...

问题

什么时候使用 outputs来自 tf.nn.dynamic_rnn() 的变量在损失函数中,什么时候使用 states多变的?这如何改变网络的抽象架构?

任何清晰度将不胜感激。

最佳答案

这基本上把它分解了:

outputs:RNN 顶层输出的完整序列。这意味着,如果您使用 MultiRNNCell,这将只是 top 单元格;这里没有来自下层牢房的任何东西。
一般来说,使用自定义的 RNNCell 实现,这几乎可以是任何东西,但是几乎所有标准单元都在这里返回 states 的序列,但是你也可以写一个自定义单元格,在将其作为输出返回之前对状态序列执行某些操作(例如线性变换)。

state(注意这是文档中的称呼,不是states)是last<的完整状态/em> 时间步长。一个重要的区别是,在 MultiRNNCell 的情况下,这将包含序列中所有 单元格的最终状态,而不仅仅是顶部的单元格!此外,此输出的精确格式/类型在很大程度上取决于所使用的 RNNCell(例如,它可能是一个张量,或一个张量元组......)。

因此,如果您只关心 MultiRNNCell 中最后一个时间步的最顶层状态,您确实有两个应该相同的选项,具体取决于个人喜好/"清晰度”:

  • outputs[:, -1, :](假设批处理主要格式)仅从顶级状态序列中提取最后一个时间步长。
  • state[-1] 仅从所有层的最终状态元组中提取顶级状态。

在其他情况下您可能没有此选择:

  • 如果您确实需要完整的序列输出,则需要使用outputs
  • 如果您需要 MultiRNNCell 中较低层的最终状态,您需要使用 state

至于相等性检查失败的原因:如果您实际使用了 ==,我相信这会检查明显不同的张量 对象 是否相等。您可以改为尝试检查两个对象的,以了解一些简单的玩具场景(微小的状态大小/序列长度)——它们应该相同。

关于python - tf.nn.dynamic_rnn() "outputs"与 "state"的概念理解,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57270125/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com