python - Input 0 of layer lstm_5 is incompatible with the layer: expected ndim=3, found ndim=2

Reposted. Author: 行者123. Updated: 2023-12-03 17:24:49

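I'm trying to create an image captioning model. Can you help me fix this error? input1 is the image vector and input2 is the caption sequence; 32 is the caption length. I want to concatenate the image vector with the embedding of the sequence and then feed the result to the decoder model.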


def define_model(vocab_size, max_length):
    input1 = Input(shape=(512,))
    input1 = tf.keras.layers.RepeatVector(32)(input1)
    print(input1.shape)

    input2 = Input(shape=(max_length,))
    e1 = Embedding(vocab_size, 512, mask_zero=True)(input2)
    print(e1.shape)

    dec1 = tf.concat([input1, e1], axis=2)
    print(dec1.shape)

    dec2 = LSTM(512)(dec1)
    dec3 = LSTM(256)(dec2)
    dec4 = Dropout(0.2)(dec3)
    dec5 = Dense(256, activation="relu")(dec4)
    output = Dense(vocab_size, activation="softmax")(dec5)
    model = tf.keras.Model(inputs=[input1, input2], outputs=output)
    model.compile(loss="categorical_crossentropy", optimizer="adam")
    print(model.summary())
    return model

ValueError: Input 0 of layer lstm_5 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 512]

Best answer

This error occurs when an LSTM layer receives 2D input instead of 3D input. For example:

(64, 100)
The correct format is (n_samples, time_steps, features):
(64, 5, 100)
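When the 2D array comes from your data rather than from a previous layer, one option is to add the time axis before feeding it to the model. A minimal NumPy sketch, assuming each sample is a flat vector of 100 features; the (5, 20) split in the second option is an illustrative assumption about what those features mean, not something stated in the question:

```python
import numpy as np

# A 2D batch: 64 samples, each a flat vector of 100 features.
# An LSTM layer would reject this (ndim=2).
batch_2d = np.zeros((64, 100), dtype=np.float32)

# Option 1: treat each vector as a sequence with a single time step.
one_step = np.expand_dims(batch_2d, axis=1)   # shape (64, 1, 100)

# Option 2 (assumption): the 100 values are really 5 time steps of
# 20 features each, so split the last axis into (time_steps, features).
five_steps = batch_2d.reshape(64, 5, 20)      # shape (64, 5, 20)

print(batch_2d.ndim, one_step.shape, five_steps.shape)
```

In the question itself the 2D tensor is produced inside the model, so the fix there is architectural (see the next paragraph), not a reshape of the input data.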
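In your case, the mistake is that dec3, an LSTM layer, is fed dec2, the output of another LSTM layer. By default, the return_sequences argument of an LSTM layer is False, so the first LSTM returns only its last output, a 2D tensor of shape (batch, units), here [None, 512]. That tensor is incompatible with the next LSTM layer, which expects 3D input. I fixed your problem by setting return_sequences=True in your first LSTM layer, so it returns the full sequence with shape (batch, time_steps, units).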
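The shape flow through the question's decoder can be traced without TensorFlow. The sketch below uses hypothetical helper functions (repeat_vector, embedding, concat_last, lstm, trace) that only model output shapes, assuming the question's sizes (512-dim image vector, caption length 32, 512-dim embedding). It is not Keras code, but it shows where the ndim=2 tensor appears and why return_sequences=True removes it:

```python
# Hypothetical shape-only helpers that mimic how Keras propagates shapes.
# None stands for the batch dimension.

def repeat_vector(shape, n):
    # RepeatVector(n): (batch, features) -> (batch, n, features)
    batch, features = shape
    return (batch, n, features)

def embedding(shape, output_dim):
    # Embedding: (batch, seq_len) -> (batch, seq_len, output_dim)
    return shape + (output_dim,)

def concat_last(a, b):
    # Concatenate on the last axis; all other dimensions must match.
    if a[:-1] != b[:-1]:
        raise ValueError(f"cannot concatenate {a} and {b}")
    return a[:-1] + (a[-1] + b[-1],)

def lstm(shape, units, return_sequences=False):
    # An LSTM expects (batch, time_steps, features), i.e. ndim=3.
    if len(shape) != 3:
        raise ValueError(
            f"expected ndim=3, found ndim={len(shape)}. "
            f"Full shape received: {list(shape)}"
        )
    batch, steps, _ = shape
    # Without return_sequences only the last time step's output is kept.
    return (batch, steps, units) if return_sequences else (batch, units)

def trace(first_returns_sequences):
    image = repeat_vector((None, 512), 32)     # (None, 32, 512)
    caption = embedding((None, 32), 512)       # (None, 32, 512)
    merged = concat_last(image, caption)       # (None, 32, 1024)
    dec2 = lstm(merged, 512, return_sequences=first_returns_sequences)
    dec3 = lstm(dec2, 256)
    return dec2, dec3

print(trace(True))
try:
    trace(False)
except ValueError as err:
    print("ValueError:", err)
```

With first_returns_sequences=False, dec2 has shape (None, 512), and the second lstm call raises the same message as in the question.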
There is also an error on this line:
model = tf.keras.Model(inputs=[input1, input2], outputs=output)
input1 is no longer an Input layer, because you reassigned it. See:
input1 = Input(shape=(512,))
input1 = tf.keras.layers.RepeatVector(32)(input1)
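I renamed the second input1 (the RepeatVector output) to e0, consistent with the way you name your other variables.
Now everything works: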
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    LSTM, Concatenate, Dense, Dropout, Embedding, RepeatVector,
)

vocab_size, max_length = 1000, 32

# Keep the Input tensor in its own variable so it can be passed to Model().
input1 = Input(shape=(128,))
e0 = RepeatVector(32)(input1)  # (None, 128) -> (None, 32, 128)
print(input1.shape)

input2 = Input(shape=(max_length,))
e1 = Embedding(vocab_size, 128, mask_zero=True)(input2)  # (None, 32, 128)
print(e1.shape)

dec1 = Concatenate()([e0, e1])  # (None, 32, 256)
print(dec1.shape)

# return_sequences=True keeps the time axis, so the next LSTM gets 3D input.
dec2 = LSTM(16, return_sequences=True)(dec1)
dec3 = LSTM(16)(dec2)
dec4 = Dropout(0.2)(dec3)
dec5 = Dense(32, activation="relu")(dec4)
output = Dense(vocab_size, activation="softmax")(dec5)
model = tf.keras.Model(inputs=[input1, input2], outputs=output)
model.compile(loss="categorical_crossentropy", optimizer="adam")
model.summary()  # summary() prints itself and returns None

Output of model.summary():

Model: "model_2"
_________________________________________________________________________________
Layer (type)                      Output Shape       Param #   Connected to
=================================================================================
input_24 (InputLayer)             [(None, 128)]      0
_________________________________________________________________________________
input_25 (InputLayer)             [(None, 32)]       0
_________________________________________________________________________________
repeat_vector_12 (RepeatVector)   (None, 32, 128)    0         input_24[0][0]
_________________________________________________________________________________
embedding_11 (Embedding)          (None, 32, 128)    128000    input_25[0][0]
_________________________________________________________________________________
concatenate_7 (Concatenate)       (None, 32, 256)    0         repeat_vector_12[0][0]
                                                               embedding_11[0][0]
_________________________________________________________________________________
lstm_12 (LSTM)                    (None, 32, 16)     17472     concatenate_7[0][0]
_________________________________________________________________________________
lstm_13 (LSTM)                    (None, 16)         2112      lstm_12[0][0]
_________________________________________________________________________________
dropout_2 (Dropout)               (None, 16)         0         lstm_13[0][0]
_________________________________________________________________________________
dense_4 (Dense)                   (None, 32)         544       dropout_2[0][0]
_________________________________________________________________________________
dense_5 (Dense)                   (None, 1000)       33000     dense_4[0][0]
=================================================================================
Total params: 181,128
Trainable params: 181,128
Non-trainable params: 0
_________________________________________________________________________________
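The parameter counts in the summary can be checked by hand. A minimal sketch follows; the helper names are hypothetical, and the formulas assume Keras's defaults (use_bias=True, with one bias vector per LSTM gate):

```python
# Hypothetical helpers reproducing the parameter-count arithmetic
# for the layers in the fixed model.

def embedding_params(vocab_size, output_dim):
    # One output_dim-sized vector per vocabulary entry.
    return vocab_size * output_dim

def lstm_params(input_dim, units):
    # 4 gates, each with an input kernel, a recurrent kernel and a bias.
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    # Weight matrix plus one bias per unit.
    return input_dim * units + units

layers = {
    "embedding_11": embedding_params(1000, 128),
    "lstm_12": lstm_params(256, 16),   # input: 128 (image) + 128 (embedding)
    "lstm_13": lstm_params(16, 16),
    "dense_4": dense_params(16, 32),
    "dense_5": dense_params(32, 1000),
}
total = sum(layers.values())

for name, count in layers.items():
    print(f"{name}: {count}")
print(f"Total params: {total:,}")
```

Note that lstm_12's input dimension is 256 because the concatenation doubles the feature axis; that is where most of its 17,472 weights come from.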

A similar question about this error ("Input 0 of layer lstm_5 is incompatible with the layer: expected ndim=3, found ndim=2") can be found on Stack Overflow: https://stackoverflow.com/questions/62395559/
