gpt4 book ai didi

Keras-组合模型如何仅提取自动编码器上的部分

转载 作者:行者123 更新时间:2023-12-05 07:06:01 25 4
gpt4 key购买 nike

我已经使用 KERAS 编译了 2 个模型(分类和自动编码器),我能够评估该模型并且按照下面的方式运行没有问题。

   model.compile(loss={'classification': 'categorical_crossentropy', 
'autoencoder': 'mean_squared_error'},
optimizer='adam',
metrics={'classification': 'accuracy'})

history = model.fit(x_train,
{'classification': y_train, 'autoencoder': x_train},
batch_size=300,
epochs=1,
validation_data= (x_test, {'classification': y_test}),
verbose=1)

第二部分要求我仅利用自动编码器上的模型部分,并可视化 8 个图像样本。请引用下面的代码,它无法运行,因为代码是针对整个模型的,我如何只提取自动编码器上的模型部分来绘制图像?

# Generate reconstructions
num_reconstructions = 8
samples = x_test[:num_reconstructions]
targets = y_test[:num_reconstructions]
reconstructions = model.autoencoder.predict(samples)

import numpy as np

# Plot reconstructions
for i in np.arange(0, num_reconstructions):
# Get the sample and the recoax = pp.subplot(111)nstruction
sample = samples[i][:, :, 0]
reconstruction = reconstructions[i][:, :, 0]
input_class = targets[i]
# Matplotlib preparations
fig, axes = plt.subplots(1, 2)
# Plot sample and reconstruciton
axes[0].imshow(sample)
axes[0].set_title('Original image')
axes[1].imshow(reconstruction)
axes[1].set_title('Reconstruction with Conv2DTranspose')
fig.suptitle(f'MNIST target = {input_class}')
plt.show()

我的网络架构师见下文:Keras combined model

  • 我知道一种方法是在网络架构之后重新训练一个只有自动编码器的模型,但那将是一个不同的模型,与之前评估的不同,损失/准确度对应于评估的自动编码器/分类在问题开始时在一起。

最佳答案

这可以毫无问题地完成

我重新建议您模型的一个版本:

inp = Input((28,28,1))
enc = Conv2D(63, 3, padding='same')(inp)
enc = MaxPool2D()(enc)

clas = Flatten()(enc)
clas = Dense(1000)(clas)
clas = Dropout(0.3)(clas)
clas = Dense(10, activation='softmax', name='classification')(clas)

dec = Dense(1000)(enc)
dec = Conv2DTranspose(63, 3, padding='same')(dec)
dec = Conv2D(1, 3, padding='same')(dec)
dec = UpSampling2D(name='autoencoder')(dec)

model = Model(inp, [clas,dec])
model.compile(loss={'classification': 'sparse_categorical_crossentropy', 'autoencoder': 'mean_squared_error'},
optimizer='adam',
metrics={'classification': 'accuracy'})

我创建虚拟数据并拟合整个结构(分类 + 自动编码器)

X = np.random.uniform(0,1, (8,28,28,1))
y = np.random.randint(0,10, 8)

model.fit(X, [y,X], epochs=3)
prob, rec1 = model.predict(X)

拟合后我只提取自动编码器部分

autoenc = Model(inp, dec)
rec2 = autoenc.predict(X)

查看结果(rec1必须等于rec2)

(rec1 == rec2).all() # True ===> correct

这里是完整的运行示例:https://colab.research.google.com/drive/12CJzXHz8fdjAv2Jvmp5IV_D-UDiHWtTY?usp=sharing

关于Keras-组合模型如何仅提取自动编码器上的部分,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62616490/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com