gpt4 book ai didi

python - 使用中间层作为输入和输出的 keras 模型

转载 作者:太空宇宙 更新时间:2023-11-04 01:54:16 25 4
gpt4 key购买 nike

我在 Keras(Tensoflow 后端)中有一个基本的 LSTM 自动编码器。模型结构如下:

l0 = Input(shape=(10, 2))
l1 = LSTM(16, activation='relu', return_sequences=True)(l0)
l2 = LSTM(8, activation='relu', return_sequences=False)(l1)
l3 = RepeatVector(10)(l2)
l4 = LSTM(8, activation='relu', return_sequences=True)(l3)
l5 = LSTM(16, activation='relu', return_sequences=True)(l4)
l6 = TimeDistributed(Dense(2))(l5)

我可以如下提取和编译编码器和自动编码器:

encoder = Model(l0, l2)
auto_encoder = Model(l0, l6)
auto_encoder.compile(optimizer='rmsprop', loss='mse', metrics=['mse'])

但是,当我尝试使用中间层制作模型时,例如:

decoder = Model(inputs=l3, outputs=l6)

我收到以下错误:

ValueError: Graph disconnected: cannot obtain value for tensor Tensor("input_12:0", shape=(?, 10, 2), dtype=float32) at layer "input_12". The following previous layers were accessed without issue: []

我不明白 l3l6 是如何相互断开连接的!我还尝试使用 get_layer(...).inputget_layer(...).output 制作解码器,但它会抛出同样的错误。

一个解释会对我有很大帮助。

最佳答案

问题是您尝试创建的模型没有输入层:

解码器=模型(输入=l3,输出=l6)

您可以创建一个具有正确形状的新 Input() 层,然后访问每个现有层。像这样:

input_layer = Input(shape=(8,))
l3 = auto_encoder.layers[3](input_layer)
l4 = auto_encoder.layers[4](l3)
l5 = auto_encoder.layers[5](l4)
l6 = auto_encoder.layers[6](l5)

decoder = Model(input_layer, l6)
decoder.summary()
Model: "model_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_14 (InputLayer) [(None, 8)] 0
_________________________________________________________________
repeat_vector_2 (RepeatVecto (None, 10, 8) 0
_________________________________________________________________
lstm_12 (LSTM) (None, 10, 8) 544
_________________________________________________________________
lstm_13 (LSTM) (None, 10, 16) 1600
_________________________________________________________________
time_distributed_1 (TimeDist (None, 10, 2) 34
=================================================================
Total params: 2,178
Trainable params: 2,178
Non-trainable params: 0

关于python - 使用中间层作为输入和输出的 keras 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57247888/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com