gpt4 book ai didi

python - Tensorflow:如何在序列长度不同的 RNN 输出中添加偏差

转载 作者:太空宇宙 更新时间:2023-11-03 15:53:58 26 4
gpt4 key购买 nike

首先让我解释一下 RNN 的输入和目标值。我的数据集由序列组成(例如 4、7、1、23、42、69)。 RNN 被训练来预测每个序列中的下一个值。因此,除了最后一个值之外的所有值都是输入,并且除了第一个值之外的所有值都是目标值。每个值都表示为 1-HOT 向量。

我在 Tensorflow 中有一个 RNN,其中 RNN (tf.dynamic_rnn) 的输出通过前馈层发送。输入序列的长度不同,因此我使用sequence_length参数来指定批处理中每个序列的长度。 RNN 层的输出是每个时间步输出的张量。大多数序列具有相同的长度,但有些序列更短。当发送较短的序列时,我会得到额外的全零向量(作为填充)。

问题是我想通过前馈层发送 RNN 层的输出。如果我在此前馈层中添加偏差,则附加的全零向量将变为非零。没有偏差,只有权重,这工作得很好,因为全零向量不受乘法的影响。因此,在没有偏差的情况下,我也可以将目标向量设置为全零,因此它们不会影响向后传递。但如果添加偏差,我不知道要在填充/虚拟目标向量中放入什么。

所以网络看起来像这样:

[INPUT (1-HOT vectors, one vector for each value in the sequence)]
V
[GRU layer (smaller size than the input layer)]
V
[Feedforward layer (outputs vectors of the same size as the input)]

这是代码:

# [batch_size, max_sequence_length, size of 1-HOT vectors]
x = tf.placeholder(tf.float32, [None, max_length, n_classes])
y = tf.placeholder(tf.int32, [None, max_length, n_classes])
session_length = tf.placeholder(tf.int32, [None])

outputs, state = rnn.dynamic_rnn(
rnn_cell.GRUCell(num_hidden),
x,
dtype=tf.float32,
sequence_length=session_length
)

layer = {'weights':tf.Variable(tf.random_normal([n_hidden, n_classes])),
'biases':tf.Variable(tf.random_normal([n_classes]))}

# Flatten to apply same weights to all timesteps
outputs = tf.reshape(outputs, [-1, n_hidden])

prediction = tf.matmul(output, layer['weights']) # + layer['bias']

error = tf.nn.softmax_cross_entropy_with_logits(prediction,y)

最佳答案

您可以添加偏差,但从损失函数中屏蔽掉不相关的序列元素。

查看example来自 im2txt 项目:

weights = tf.to_float(tf.reshape(self.input_mask, [-1])) # these are the masks

# Compute losses.
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
batch_loss = tf.div(tf.reduce_sum(tf.mul(losses, weights)),
tf.reduce_sum(weights),
name="batch_loss") # Here the irrelevant sequence elements are masked out

此外,要生成掩码,请参阅函数 batch_with_dynamic_pad在同一个项目中,在 ops/inputs 下

关于python - Tensorflow:如何在序列长度不同的 RNN 输出中添加偏差,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40973868/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com