gpt4 book ai didi

python - 如何在 TensorFlow 中处理具有可变长度序列的批处理?

转载 作者:IT老高 更新时间:2023-10-28 20:37:57 26 4
gpt4 key购买 nike

我尝试使用 RNN(特别是 LSTM)进行序列预测。但是,我遇到了可变序列长度的问题。例如,

sent_1 = "I am flying to Dubain"
sent_2 = "I was traveling from US to Dubai"

我正在尝试使用基于此 Benchmark for building a PTB LSTM model 的简单 RNN 预测当前单词之后的下一个单词.

但是,num_steps 参数(用于展开到之前的隐藏状态)在每个 Tensorflow 的 epoch 中应该保持不变。基本上,批处理句子是不可能的,因为句子的长度不同。

 # inputs = [tf.squeeze(input_, [1])
# for input_ in tf.split(1, num_steps, inputs)]
# outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)

这里,num_steps 需要在我的情况下为每个句子进行更改。我尝试了几种 hack,但似乎没有任何效果。

最佳答案

您可以使用以下描述的分桶和填充的想法:

     Sequence-to-Sequence Models

另外,创建RNN网络的rnn函数接受参数sequence_length。

例如,您可以创建相同大小的句子桶,用必要数量的零或代表零字的占位符填充它们,然后将它们与 seq_length = len(zero_words) 一起提供。

seq_length = tf.placeholder(tf.int32)
outputs, states = rnn.rnn(cell, inputs, initial_state=initial_state, sequence_length=seq_length)

sess = tf.Session()
feed = {
seq_length: 20,
#other feeds
}
sess.run(outputs, feed_dict=feed)

也看看这个 reddit 线程:

    Tensorflow basic RNN example with 'variable length' sequences

关于python - 如何在 TensorFlow 中处理具有可变长度序列的批处理?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34670112/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com