作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我最近发现 LayerNormBasicLSTMCell 是实现了层归一化和 dropout 的 LSTM 版本。因此,我将使用 LSTMCell 的原始代码替换为 LayerNormBasicLSTMCell。这一变化不仅将测试准确率从约 96% 降低到约 92%,而且训练时间也更长(约 33 小时)(原始训练时间为约 6 小时)。所有参数都相同:epoch 数 (10)、堆叠层数 (3)、隐藏向量大小数 (250)、drop out keep prob (0.5),...硬件也相同。
我的问题是:我在这里做错了什么?
我的原始模型(使用 LSTMCell):
# Batch normalization of the raw input
tf_b_VCCs_AMs_BN1 = tf.layers.batch_normalization(
tf_b_VCCs_AMs, # the input vector, size [#batches, #time_steps, 2]
axis=-1, # axis that should be normalized
training=Flg_training, # Flg_training = True during training, and False during test
trainable=True,
name="Inputs_BN"
)
# Bidirectional dynamic stacked LSTM
##### The part I changed in the new model (start) #####
dropcells = []
for iiLyr in range(3):
cell_iiLyr = tf.nn.rnn_cell.LSTMCell(num_units=250, state_is_tuple=True)
dropcells.append(tf.nn.rnn_cell.DropoutWrapper(cell=cell_iiLyr, output_keep_prob=0.5))
##### The part I changed in the new model (end) #####
MultiLyr_cell = tf.nn.rnn_cell.MultiRNNCell(cells=dropcells, state_is_tuple=True)
outputs, states = tf.nn.bidirectional_dynamic_rnn(
cell_fw=MultiLyr_cell,
cell_bw=MultiLyr_cell,
dtype=tf.float32,
sequence_length=tf_b_lens, # the actual lengths of the input sequences (tf_b_VCCs_AMs_BN1)
inputs=tf_b_VCCs_AMs_BN1,
scope = "BiLSTM"
)
我的新模型(使用 LayerNormBasicLSTMCell):
...
dropcells = []
for iiLyr in range(3):
cell_iiLyr = tf.contrib.rnn.LayerNormBasicLSTMCell(
num_units=250,
forget_bias=1.0,
activation=tf.tanh,
layer_norm=True,
norm_gain=1.0,
norm_shift=0.0,
dropout_keep_prob=0.5
)
dropcells.append(cell_iiLyr)
...
最佳答案
也许应该为 dropout_keep_prob
分配一个占位符而不是常量值。尝试在训练时分配 0.5
,在推理时分配 1.0
。只是猜测。
关于tensorflow - 为什么 LayerNormBasicLSTMCell 比 LSTMCell 慢得多且准确度低?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45150101/
我最近发现 LayerNormBasicLSTMCell 是实现了层归一化和 dropout 的 LSTM 版本。因此,我将使用 LSTMCell 的原始代码替换为 LayerNormBasicLST
我的tensorflow版本是1.0.0。当我使用 tf.contrib.rnn.GRUCell(n_hidden_units) 正常运行时,但使用 tf.contrib.rnn.LayerNor
我是一名优秀的程序员,十分优秀!