gpt4 book ai didi

python - 从 Keras 中经过训练的自动编码器模型中获取解码器

转载 作者:行者123 更新时间:2023-12-01 13:15:34 25 4
gpt4 key购买 nike

我正在训练一个深度自动编码器以将人脸映射到 128 维潜在空间,然后将它们解码回其原始的 128x128x3 格式。

我希望在训练自动编码器之后,我能够以某种方式“切片”自动编码器的后半部分,即负责将潜在空间 (128,) 映射到图像空间 (128, 128) 的解码器网络, 3) 通过使用功能性 Keras API 和 autoenc_model.get_layer()

这是我的 model 的相关图层:

INPUT_SHAPE=(128,128,3)
input_img = Input(shape=INPUT_SHAPE, name='enc_input')

#1
x = Conv2D(64, (3, 3), padding='same', activation='relu')(input_img)
x = BatchNormalization()(x)

//Many Conv2D, BatchNormalization(), MaxPooling() layers
.
.
.

#Flatten
fc_input = Flatten(name='enc_output')(x)

y = Dropout(DROP_RATE)(fc_input)
y = Dense(128, activation='relu')(y)
y = Dropout(DROP_RATE)(y)
fc_output = Dense(128, activation='linear')(y)

#Reshape
decoder_input = Reshape((8, 8, 2), name='decoder_input')(fc_output)

#Decoder part

#UnPooling-1
z = UpSampling2D()(decoder_input)
//Many Conv2D, BatchNormalization, UpSampling2D layers
.
.
.
#16
decoder_output = Conv2D(3, (3, 3), padding='same', activation='linear', name='decoder_output')(z)

autoenc_model = Model(input_img, decoder_output)

here是包含整个模型架构的笔记本。

为了从经过训练的自动编码器中获取解码器网络,我尝试使用:

dec_model = Model(inputs=autoenc_model.get_layer('decoder_input').input, outputs=autoenc_model.get_layer('decoder_output').output)

dec_model = Model(autoenc_model.get_layer('decoder_input'), autoenc_model.get_layer('decoder_output'))

这两个似乎都不起作用。

我需要从自动编码器中提取解码器层,因为我想先训练整个自动编码器模型,然后独立使用编码器和解码器。

我在其他任何地方都找不到满意的答案。 Keras blog article关于构建自动编码器仅涵盖如何为 2 层自动编码器提取解码器。

解码器输入/输出形状应为:(128, ) 和 (128, 128, 3),分别是“decoder_input”的输入形状和“decoder_output”层的输出形状。

最佳答案

需要做一些改变:

z = UpSampling2D()(decoder_input)

direct_input = Input(shape=(8,8,2), name='d_input')
#UnPooling-1
z = UpSampling2D()(direct_input)

autoenc_model = Model(input_img, decoder_output)

dec_model = Model(direct_input, decoder_output)
autoenc_model = Model(input_img, dec_model(decoder_input))

现在,您可以在自动编码器上训练并使用解码器进行预测。

import numpy as np
autoenc_model.fit(np.ones((5,128,128,3)), np.ones((5,128,128,3)))
dec_model.predict(np.ones((1,8,8,2)))

你也可以引用这个独立的例子: https://github.com/keras-team/keras/blob/master/examples/variational_autoencoder.py

关于python - 从 Keras 中经过训练的自动编码器模型中获取解码器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55458306/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com