作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我想创建一个使用注意力机制的多层动态 RNN 解码器。为此,我首先创建一个注意力机制:
attention_mechanism = BahdanauAttention(num_units=ATTENTION_UNITS,
memory=encoder_outputs,
normalize=True)
然后我使用 AttentionWrapper
来包装一个具有注意力机制的 LSTM 单元:
attention_wrapper = AttentionWrapper(cell=self._create_lstm_cell(DECODER_SIZE),
attention_mechanism=attention_mechanism,
output_attention=False,
alignment_history=True,
attention_layer_size=ATTENTION_LAYER_SIZE)
其中self._create_lstm_cell
定义如下:
@staticmethod
def _create_lstm_cell(cell_size):
return BasicLSTMCell(cell_size)
然后我做一些簿记(例如创建我的 MultiRNNCell
、创建初始状态、创建 TrainingHelper
等)
attention_zero = attention_wrapper.zero_state(batch_size=tf.flags.FLAGS.batch_size, dtype=tf.float32)
# define initial state
initial_state = attention_zero.clone(cell_state=encoder_final_states[0])
training_helper = TrainingHelper(inputs=self.y, # feed in ground truth
sequence_length=self.y_lengths) # feed in sequence lengths
layered_cell = MultiRNNCell(
[attention_wrapper] + [ResidualWrapper(self._create_lstm_cell(cell_size=DECODER_SIZE))
for _ in range(NUMBER_OF_DECODER_LAYERS - 1)])
decoder = BasicDecoder(cell=layered_cell,
helper=training_helper,
initial_state=initial_state)
decoder_outputs, decoder_final_state, decoder_final_sequence_lengths = dynamic_decode(decoder=decoder,
maximum_iterations=tf.flags.FLAGS.max_number_of_scans // 12,
impute_finished=True)
但我收到以下错误:AttributeError: 'LSTMStateTuple' object has no attribute 'attention'
。
向 MultiRNNCell 动态解码器添加注意机制的正确方法是什么?
最佳答案
您是否尝试过使用 attention wrapper由 tf.contrib 提供?
这是一个同时使用注意力包装器和 dropout 的示例:
cells = []
for i in range(n_layers):
cell = tf.contrib.rnn.LSTMCell(n_hidden, state_is_tuple=True)
cell = tf.contrib.rnn.AttentionCellWrapper(
cell, attn_length=40, state_is_tuple=True)
cell = tf.contrib.rnn.DropoutWrapper(cell,output_keep_prob=0.5)
cells.append(cell)
cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)
init_state = cell.zero_state(batch_size, tf.float32)
关于tensorflow - 如何将 AttentionMechanism 与 MultiRNNCell 和 dynamic_decode 一起使用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44937105/
我想创建一个使用注意力机制的多层动态 RNN 解码器。为此,我首先创建一个注意力机制: attention_mechanism = BahdanauAttention(num_units=ATTENT
我是一名优秀的程序员,十分优秀!