gpt4 book ai didi

python - 如何在 TensorFlow 中打印出 LSTM 门的值?

转载 作者:太空狗 更新时间:2023-10-30 01:36:54 24 4
gpt4 key购买 nike

我将 TensorFlow LSTM 用于语言模型(我有一个单词序列并想预测下一个单词),并且在运行语言模型时,我想打印出 forget 的值、输入、转换和输出门在每一步。我该怎么做?

来自检查 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/rnn_cell.py 中的代码,我看到 LayerNormBasicLSTMCell 类有一个 call 方法,其中包含我要打印的 i, j, f, o 变量。

  def call(self, inputs, state):
"""LSTM cell with layer normalization and recurrent dropout."""
c, h = state
args = array_ops.concat([inputs, h], 1)
concat = self._linear(args)

i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
if self._layer_norm:
i = self._norm(i, "input")
j = self._norm(j, "transform")
f = self._norm(f, "forget")
o = self._norm(o, "output")

g = self._activation(j)
if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)

new_c = (c * math_ops.sigmoid(f + self._forget_bias)
+ math_ops.sigmoid(i) * g)
if self._layer_norm:
new_c = self._norm(new_c, "state")
new_h = self._activation(new_c) * math_ops.sigmoid(o)

new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
return new_h, new_state

但是,有没有一种简单的方法可以让我打印出这些变量?或者我是否必须在运行 LTSM 的脚本中基本上重新创建此方法中的相关代码行?

最佳答案

我曾经在 git issue 中问过类似的问题。响应是原始单元格仅返回 ch(这也是每一步的输出 y)。如果要获取内部变量,需要自己动手。

这是链接:https://github.com/tensorflow/tensorflow/issues/5731

关于python - 如何在 TensorFlow 中打印出 LSTM 门的值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43884057/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com