gpt4 book ai didi

python - 计算每个时间步长的可变长度输出的成本

转载 作者:太空宇宙 更新时间:2023-11-03 14:19:21 26 4
gpt4 key购买 nike

我正在尝试使用 LSTM 单元和 Tensorflow 创建文本生成神经网络。我正在以时间主格式 [time_steps、batch_size、input_size] 的句子训练网络,并且我希望每个时间步都预测序列中的下一个单词。该序列在时间步长之前用空值填充,并且单独的占位符包含批处理中每个序列的长度。

有很多关于随时间反向传播概念的信息,但是我找不到有关 tensorflow 中用于可变长度序列成本计算的实际实现的任何信息。由于序列的末尾已填充,我假设我不想计算填充部分的成本。所以我需要一种方法将输出从第一个输出剪辑到序列的末尾。

这是我目前拥有的代码:

    outputs = []
states = []
cost = 0
for i in range(time_steps+1):
output, state = cell(X[i], state)
z1 = tf.matmul(output, dec_W1) + dec_b1
a1 = tf.nn.sigmoid(z1)
z2 = tf.matmul(a1, dec_W2) + dec_b2
a2 = tf.nn.softmax(z2)
outputs.append(a2)
states.append(state)
#== calculate cost
cost = cost + tf.nn.softmax_cross_entropy_with_logits(logits=z2, labels=y[i])
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)

此代码无需可变长度序列即可工作。但是,如果我在末尾添加了填充值,那么它也会计算填充部分的成本,这没有多大意义。

如何只计算序列长度上限之前的输出成本?

最佳答案

解决了!

在深入研究了很多示例之后(大多数都在更高级别的框架中,例如 Keras,这很痛苦),我发现您必须创建一个掩码!回想起来似乎很简单。

以下代码用于创建 1 和 0 的掩码,然后按元素将其与矩阵相乘(这将是成本值)

x = tf.placeholder(tf.float32)
seq = tf.placeholder(tf.int32)

def mask_by_length(input_matrix, length):
'''
Input matrix is a 2d tensor [batch_size, time_steps]
length is a 1d tensor
length refers to the length of input matrix axis 1
'''
length_transposed = tf.expand_dims(length, 1)

# Create range in order to compare length to
range = tf.range(tf.shape(input_matrix)[1])
range_row = tf.expand_dims(range, 0)

# Use the logical operations to create a mask
mask = tf.less(range_row, length_transposed)

# cast boolean to int to finalize mask
mask_result = tf.cast(mask, dtype=tf.float32)

# Element-wise multiplication to cancel out values in the mask
result = tf.multiply(mask_result, input_matrix)

return result

mask_values = mask_by_length(x, seq)

输入值(主要时间)[time_steps,batch_size]

[[ 0.71, 0.22, 1.42, -0.28, 0.99] [ 0.41、2.24、0.09、0.74、0.65]]

序列值[batch_size]

[2, 3]

输出(主要时间)[time_steps,batch_size]

[[ 0.71, 0.22, 0, 0, 0, ] [ 0.41, 2.24, 0.09, 0, 0, ]]

关于python - 计算每个时间步长的可变长度输出的成本,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48040685/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com