gpt4 book ai didi

python - Tensorflow:低级 LSTM 实现

转载 作者:太空宇宙 更新时间:2023-11-03 11:38:36 27 4
gpt4 key购买 nike

我正在寻找在 Tensorflow 中使用 LSTM 单元的 RNN 的低级实现。我已经实现了几个使用低级 API 的前馈网络。这对我理解 ANN 的内部工作原理有很大帮助。我可以对 RNN 做同样的事情还是建议使用 LSTM 单元的 Tensorflow 实现 (tf.nn.rnn_cell.BasicLSTMCell)?我没有在 Tensorflow 中找到 RNN 的任何低级实现。我在哪里可以找到这样的低级实现? Tensorflow 是为此设计的吗?我可以从哪里开始?我希望我的一些问题能在这里得到解答

最佳答案

1) 使用 tf.scan

RNN 的底层实现可以通过 tf.scan 实现功能。例如,对于 SimpleRNN,实现将类似于:

# our RNN variables
Wx = tf.get_variable(name='Wx', shape=[embedding_size, rnn_size])
Wh = tf.get_variable(name='Wh', shape=[rnn_size, rnn_size])
bias_rnn = tf.get_variable(name='brnn', initializer=tf.zeros([rnn_size]))


# single step in RNN
# simpleRNN formula is `tanh(WX+WH)`
def rnn_step(prev_hidden_state, x):
return tf.tanh(tf.matmul(x, Wx) + tf.matmul(prev_hidden_state, Wh) + bias_rnn)

# our unroll function
# notice that our inputs should be transpose
hidden_states = tf.scan(fn=rnn_step,
elems=tf.transpose(embed, perm=[1, 0, 2]),
initializer=tf.zeros([batch_size, rnn_size]))

# covert to previous shape
outputs = tf.transpose(hidden_states, perm=[1, 0, 2])

# extract last hidden
last_rnn_output = outputs[:, -1, :]

查看完整示例 here .

2) 使用AutoGraph

tf.scan 是一个可以实现的 for 循环 Auto-graph API 以及:

from tensorflow.python import autograph as ag

@ag.convert()
def f(x):
# ...
for ch in chars:
cell_output, (state, output) = cell.call(ch, (state, output))
hidden_outputs.append(cell_output)
hidden_outputs = autograph.stack(hidden_outputs)
# ...

查看带有签名 API 的完整示例 here .

3) 在 Numpy 中实现

如果您仍然需要深入内部来实现 RNN,请参阅 this使用 numpy 实现 RNN 的教程。

4) Keras 中的自定义 RNN 单元

参见 here .

关于python - Tensorflow:低级 LSTM 实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54275098/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com