
python - tf.keras.layers.RNN vs tf.keras.layers.StackedRNNCells: Tensorflow 2

Reposted. Author: 行者123. Updated: 2023-12-02 05:50:40

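I am trying to implement a multi-layer RNN model in TensorFlow 2.0. Building it with tf.keras.layers.StackedRNNCells and with a plain tf.keras.layers.RNN gives the same result (identical output shapes and parameter counts). Can anyone help me understand the difference between tf.keras.layers.RNN and tf.keras.layers.StackedRNNCells?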

import tensorflow as tf

# driving parameters
sz_batch = 128
sz_latent = 200
sz_sequence = 196
sz_feature = 2
n_units = 120
n_layers = 3

Multi-layer RNN with tf.keras.layers.RNN:

inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(cells, stateful=True, return_sequences=True, return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

Returns:

Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_88 (InputLayer)        [(128, 196, 2)]           0
_________________________________________________________________
rnn_61 (RNN)                 (128, 196, 120)           218880
_________________________________________________________________
dense_19 (Dense)             (128, 196, 1)             121
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0
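The 218,880 parameters of the RNN layer can be reproduced by hand. Below is a minimal sketch in plain Python. It assumes the TF 2.x GRUCell default reset_after=True, which gives each of the three gates two bias vectors. The helper names gru_cell_params and stacked_gru_params are hypothetical and are not Keras functions.

```python
# Hypothetical helpers reproducing the Keras GRU parameter count,
# assuming the TF 2.x GRUCell default reset_after=True.
def gru_cell_params(input_dim: int, units: int, reset_after: bool = True) -> int:
    kernel = 3 * input_dim * units        # input weights for update, reset and candidate gates
    recurrent_kernel = 3 * units * units  # recurrent weights for the same three gates
    bias = 3 * units * (2 if reset_after else 1)  # reset_after=True keeps two bias vectors per gate
    return kernel + recurrent_kernel + bias


def stacked_gru_params(sz_feature: int, n_units: int, n_layers: int) -> int:
    total = 0
    input_dim = sz_feature
    for _ in range(n_layers):
        total += gru_cell_params(input_dim, n_units)
        input_dim = n_units  # every later cell consumes the previous cell's output
    return total


rnn_params = stacked_gru_params(sz_feature=2, n_units=120, n_layers=3)
dense_params = 120 * 1 + 1  # Dense(1): one weight per input unit plus one bias
print(rnn_params, dense_params, rnn_params + dense_params)  # 218880 121 219001
```

The first cell sees only the 2 input features (44,640 parameters). The other two cells each see the 120-dimensional output of the cell below (87,120 parameters each). This is why the count depends on the stacking order but not on which of the two constructions is used.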

Multi-layer RNN with tf.keras.layers.RNN and tf.keras.layers.StackedRNNCells:

inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells),
                              stateful=True,
                              return_sequences=True,
                              return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

Returns:

Model: "model_14"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_89 (InputLayer)        [(128, 196, 2)]           0
_________________________________________________________________
rnn_62 (RNN)                 (128, 196, 120)           218880
_________________________________________________________________
dense_20 (Dense)             (128, 196, 1)             121
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0
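Matching summaries only show that the shapes and parameter counts agree. A stronger check is numerical: build both variants, copy the weights of one into the other, and compare outputs. The sketch below rests on a few assumptions. The build helper is hypothetical. The sizes are shrunk for speed. stateful=True is dropped so that no hidden state carries over between calls.

```python
import numpy as np
import tensorflow as tf

# Small sizes (assumed) so the check runs quickly.
n_units, n_layers, sz_sequence, sz_feature = 8, 3, 5, 2


def build(stacked: bool) -> tf.keras.Model:
    """Hypothetical helper: the question's model, with or without an explicit StackedRNNCells."""
    inputs = tf.keras.layers.Input(shape=(sz_sequence, sz_feature))
    cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
    cell = tf.keras.layers.StackedRNNCells(cells) if stacked else cells
    x = tf.keras.layers.RNN(cell, return_sequences=True)(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    return tf.keras.Model(inputs, outputs)


model_list = build(stacked=False)
model_stacked = build(stacked=True)

# Give both models identical weights; this only works if their weight layouts match.
model_stacked.set_weights(model_list.get_weights())

x = np.random.default_rng(0).normal(size=(4, sz_sequence, sz_feature)).astype("float32")
same = bool(np.allclose(model_list(x).numpy(), model_stacked(x).numpy()))
print(same)
```

set_weights succeeds because both models end up with the same weight list in the same order. Identical outputs then confirm that the two constructions compute the same function.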

Best answer

If you pass a list or tuple of cells, tf.keras.layers.RNN wraps them in tf.keras.layers.StackedRNNCells itself, so both snippets build the same model. The wrapping is done in https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/layers/recurrent.py#L390
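One way to see this, assuming a TF 2.x install, is to inspect the cell attribute of an RNN layer built from a plain list. The layer does not even need to be called for the wrapping to happen:

```python
import tensorflow as tf

cells = [tf.keras.layers.GRUCell(120) for _ in range(3)]
rnn = tf.keras.layers.RNN(cells)  # a plain Python list, as in the first snippet

# The constructor has replaced the list with a StackedRNNCells wrapper
# that holds the original cell objects in their original order.
print(type(rnn.cell).__name__)
print(isinstance(rnn.cell, tf.keras.layers.StackedRNNCells))
print(rnn.cell.cells[0] is cells[0])
```

Passing tf.keras.layers.StackedRNNCells(cells) explicitly is therefore just a more verbose spelling of the same thing.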

The original question, "python - tf.keras.layers.RNN vs tf.keras.layers.StackedRNNCells: Tensorflow 2", is on Stack Overflow: https://stackoverflow.com/questions/60624960/
