
tensorflow - How to use tensorflow seq2seq without embeddings?

Reposted · Author: 行者123 · Updated: 2023-12-04 18:57:20

I have been doing time-series prediction with LSTMs in TensorFlow. Now I want to try sequence-to-sequence (seq2seq). The official site has a tutorial showing NMT with embeddings. How can I use this new seq2seq module without embeddings, feeding the time-series "sequences" in directly?

# 1. Encoder
encoder_cell = tf.contrib.rnn.BasicLSTMCell(LSTM_SIZE)
encoder_outputs, encoder_state = tf.nn.static_rnn(
    encoder_cell,
    x,
    dtype=tf.float32)

# 2. Decoder
decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_SIZE)

helper = tf.contrib.seq2seq.TrainingHelper(
    decoder_emb_inp, decoder_lengths, time_major=True)

decoder = tf.contrib.seq2seq.BasicDecoder(
    decoder_cell, helper, encoder_state)

# Dynamic decoding
outputs, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
outputs = outputs[-1]

# The output is a linear projection of the last RNN layer's output.
weight = tf.Variable(tf.random_normal([LSTM_SIZE, N_OUTPUTS]))
bias = tf.Variable(tf.random_normal([N_OUTPUTS]))
predictions = tf.matmul(outputs, weight) + bias

If I use input_seq=x and output_seq=label, what should the arguments to TrainingHelper() be?

decoder_emb_inp ???
decoder_lengths ???

Here input_seq is the first 8 points of the sequence, and output_seq is the last 2 points of the sequence.
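
For example, the windows could be built like this (a toy sine series and the NumPy code below are just for illustration, not part of my actual data pipeline):

# Illustrative construction of x (first 8 points) and label (last 2 points).
import numpy as np

series = np.sin(np.linspace(0.0, 20.0, 200)).astype(np.float32)
n_in, n_out = 8, 2
window = n_in + n_out
x = np.array([series[i:i + n_in] for i in range(len(series) - window + 1)])
label = np.array([series[i + n_in:i + window] for i in range(len(series) - window + 1)])
# x.shape == (191, 8), label.shape == (191, 2)
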
Thanks in advance!

Best Answer

I got this working without embeddings by using the very basic InferenceHelper:

inference_helper = tf.contrib.seq2seq.InferenceHelper(
    sample_fn=lambda outputs: outputs,
    sample_shape=[dim],
    sample_dtype=dtypes.float32,
    start_inputs=start_tokens,
    end_fn=lambda sample_ids: False)

My inputs are floats with the shape [batch_size, time, dim]. In the example below dim is 1, but this can easily be extended to more dimensions. Here is the relevant part of the code:
projection_layer = tf.layers.Dense(
    units=1,  # = dim
    kernel_initializer=tf.truncated_normal_initializer(
        mean=0.0, stddev=0.1))

# Training Decoder
training_decoder_output = None
with tf.variable_scope("decode"):
    # output_data doesn't exist during the prediction phase.
    if output_data is not None:
        # Prepend the "go" token
        go_tokens = tf.constant(go_token, shape=[batch_size, 1, 1])
        dec_input = tf.concat([go_tokens, target_data], axis=1)

        # Helper for the training process.
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=dec_input,
            sequence_length=[output_size] * batch_size)

        # Basic decoder
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            dec_cell, training_helper, enc_state, projection_layer)

        # Perform dynamic decoding using the decoder
        training_decoder_output = tf.contrib.seq2seq.dynamic_decode(
            training_decoder, impute_finished=True,
            maximum_iterations=output_size)[0]

# Inference Decoder
# Reuses the same parameters trained by the training process.
with tf.variable_scope("decode", reuse=tf.AUTO_REUSE):
    start_tokens = tf.constant(
        go_token, shape=[batch_size, 1])

    # The sample_ids are the actual output in this case (not dealing with any logits here).
    # My end_fn is always False because I'm working with a generator that will stop giving
    # more data. You may extend the end_fn as you wish. E.g. you can append end_tokens
    # and make end_fn be true when the sample_id is the end token.
    inference_helper = tf.contrib.seq2seq.InferenceHelper(
        sample_fn=lambda outputs: outputs,
        sample_shape=[1],  # again because dim=1
        sample_dtype=dtypes.float32,
        start_inputs=start_tokens,
        end_fn=lambda sample_ids: False)

    # Basic decoder
    inference_decoder = tf.contrib.seq2seq.BasicDecoder(
        dec_cell, inference_helper, enc_state, projection_layer)

    # Perform dynamic decoding using the decoder
    inference_decoder_output = tf.contrib.seq2seq.dynamic_decode(
        inference_decoder, impute_finished=True,
        maximum_iterations=output_size)[0]
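
The snippet assumes a few pieces defined elsewhere (dtypes, enc_state, dec_cell, go_token, output_size, batch_size, and the output_data/target_data tensors). A rough sketch of how they could be set up for a 1-D time series; the cell sizes, placeholder shapes and concrete values here are just placeholder assumptions:

# Sketch of the surrounding setup for the snippet above (TF 1.x);
# names follow the snippet, concrete values are assumptions.
import tensorflow as tf
from tensorflow.python.framework import dtypes  # provides the dtypes.float32 used above

batch_size = 32
output_size = 2    # number of decoder steps to predict
go_token = 0.0     # float "go" value prepended to the decoder input

# Encoder inputs: [batch_size, time, 1]; decoder targets: [batch_size, output_size, 1]
input_data = tf.placeholder(tf.float32, [batch_size, None, 1])
target_data = tf.placeholder(tf.float32, [batch_size, output_size, 1])
output_data = target_data  # in this sketch the targets double as the output_data checked above

# Encoder: a plain LSTM over the raw float sequence (no embedding lookup).
enc_cell = tf.nn.rnn_cell.LSTMCell(num_units=64)
enc_outputs, enc_state = tf.nn.dynamic_rnn(enc_cell, input_data, dtype=tf.float32)

# Decoder cell of the same size, so enc_state can serve as its initial state.
dec_cell = tf.nn.rnn_cell.LSTMCell(num_units=64)
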

Have a look at this question. I also found this tutorial very useful for understanding seq2seq models, even though it does use embeddings. So just replace their GreedyEmbeddingHelper with an InferenceHelper like the one I posted above.
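
Since the decoder emits real values rather than logits here, a regression loss on training_decoder_output.rnn_output is a natural training objective. A minimal sketch (mean squared error and Adam are just example choices, not shown in the code above):

# Sketch of a training objective; MSE and Adam are example choices.
predictions = training_decoder_output.rnn_output  # [batch_size, output_size, 1]
loss = tf.losses.mean_squared_error(labels=target_data, predictions=predictions)
train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(loss)
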

P.S.: I posted the full code at https://github.com/Andreea-G/tensorflow_examples

About "tensorflow - How to use tensorflow seq2seq without embeddings?": a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/49134432/
