gpt4 book ai didi

keras - 具有共享层的模型集成

转载 作者:行者123 更新时间:2023-12-04 17:16:23 25 4
gpt4 key购买 nike

在 keras 中,我想训练一组共享一些层的模型。它们具有以下形式:

x ---> f(x) ---> g_1(f(x))

x ---> f(x) ---> g_2(f(x))

...

x ---> f(x) ---> g_n(f(x))

这里 f(x) 是一些重要的共享层。 g_1 到 g_n 有它们特定的参数。

在每个训练阶段,数据 x 被送入 n 个网络之一,比如第 i 个网络。然后通过基于梯度的优化器最小化/减少 g_i(f(x)) 上的损失。我如何定义和训练这样的模型?

提前致谢!

最佳答案

您可以使用功能模型轻松地做到这一点。

一个小例子..你可以在它的基础上构建:

import numpy as np
from keras.models import Model
from keras.layers import Dense, Input

X = np.empty(shape=(1000,100))
Y1 = np.empty(shape=(1000))
Y2 = np.empty(shape=(1000,2))
Y3 = np.empty(shape=(1000,3))

inp = Input(shape=(100,))
dense_f1 = Dense(50)
dense_f2 = Dense(20)

f = dense_f2(dense_f1(inp))

dense_g1 = Dense(1)
g1 = dense_g1(f)

dense_g2 = Dense(2)
g2 = dense_g2(f)

dense_g3 = Dense(3)
g3 = dense_g3(f)


model = Model([inp], [g1, g2, g3])
model.compile(loss=['mse', 'binary_crossentropy', 'categorical_crossentropy'], optimizer='rmsprop')

model.summary()

model.fit([X], [Y1, Y2, Y3], nb_epoch=10)

编辑:

根据您的意见,您始终可以制作不同的模型,并根据您需要的训练方式自行编写训练循环。您可以在 model.summary() 中看到所有模型都共享初始层。这是示例的扩展
model1 = Model(inp, g1)
model1.compile(loss=['mse'], optimizer='rmsprop')
model2 = Model(inp, g2)
model2.compile(loss=['binary_crossentropy'], optimizer='rmsprop')
model3 = Model(inp, g3)
model3.compile(loss=['categorical_crossentropy'], optimizer='rmsprop')
model1.summary()
model2.summary()
model3.summary()

batch_size = 10
nb_epoch=10
n_batches = X.shape[0]/batch_size


for iepoch in range(nb_epoch):
for ibatch in range(n_batches):
x_batch = X[ibatch*batch_size:(ibatch+1)*batch_size]
if ibatch%3==0:
y_batch = Y1[ibatch*batch_size:(ibatch+1)*batch_size]
model1.train_on_batch(x_batch, y_batch)
elif ibatch%3==1:
y_batch = Y2[ibatch*batch_size:(ibatch+1)*batch_size]
model2.train_on_batch(x_batch, y_batch)
else:
y_batch = Y3[ibatch*batch_size:(ibatch+1)*batch_size]
model3.train_on_batch(x_batch, y_batch)

关于keras - 具有共享层的模型集成,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41603357/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com