gpt4 book ai didi

python - Tensorflow LSTM - LSTM 单元上的矩阵乘法

转载 作者:行者123 更新时间:2023-11-28 18:23:16 26 4
gpt4 key购买 nike

我正在 Tensorflow 中制作 LSTM 神经网络。

输入张量大小为 92。

import tensorflow as tf
from tensorflow.contrib import rnn
import data

test_x, train_x, test_y, train_y = data.get()

# Parameters
learning_rate = 0.001
epochs = 100
batch_size = 64
display_step = 10

# Network Parameters
n_input = 28 # input size
n_hidden = 128 # number of hidden layers
n_classes = 20 # output size

# Placeholders
x = tf.placeholder(dtype=tf.float32, shape=[None, n_input])
y = tf.placeholder(dtype=tf.float32, shape=[None, n_classes])

# Network
def LSTM(x):
W = tf.Variable(tf.random_normal([n_hidden, n_classes]), dtype=tf.float32) # weights
b = tf.Variable(tf.random_normal([n_classes]), dtype=tf.float32) # biases

x_shape = 92

x = tf.transpose(x)
x = tf.reshape(x, [-1, n_input])
x = tf.split(x, x_shape)

lstm = rnn.BasicLSTMCell(
num_units=n_hidden,
forget_bias=1.0
)
outputs, states = rnn.static_rnn(
cell=lstm,
inputs=x,
dtype=tf.float32
)

output = tf.matmul( outputs[-1], W ) + b

return output

# Train Network
def train(x):
prediction = LSTM(x)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
output = sess.run(prediction, feed_dict={"x": train_x})
print(output)

train(x)

我没有收到任何错误,但我正在输入一个大小为 92 的张量,并且 LSTM 函数中的矩阵乘法返回一个包含一个结果向量的列表,当所需数量为 92 时,每个结果向量一个输入。

问题是我的矩阵只乘以输出数组中的最后一项吗?像这样:

output = tf.matmul( outputs[-1], W ) + b

代替:

output = tf.matmul( outputs, W ) + b

这是我在执行后者时遇到的错误:

ValueError: Shape must be rank 2 but is rank 3 for 'MatMul' (op: 'MatMul') with input shapes: [92,?,128], [128,20].

最佳答案

static_rnn 用于制作最简单的循环神经网络。 Here's the tf documentation .所以它的输入应该是一个张量序列。假设您想输入 4 个单词,分别是“Hi”、“how”、“Are”、“you”。因此,您的输入占位符应由对应于每个单词的四个 n(每个输入向量的大小)维向量组成。

我认为您的占位符有问题。您应该使用 RNN 的输入数量对其进行初始化。 28 是每个向量中的维数。我相信 92 是序列的长度。 (更像是 92 lstm 单元格)

在输出列表中,您将获得一组等于序列长度的向量,每个向量的大小等于隐藏单元的数量。

关于python - Tensorflow LSTM - LSTM 单元上的矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43402017/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com