gpt4 book ai didi

python - 如何使用 Theano 在 Keras 中实现共享相同权重的 2 个相同的层分支

转载 作者:太空宇宙 更新时间:2023-11-03 14:24:08 36 4
gpt4 key购买 nike

我正在尝试实现两个相同的分支,在一定程度上共享相同的权重。您在此处看到的图形 enter image description here是我所拥有的简化模型。所以我有一个输入:一个负输入和一个正输入,从 conv1_1_x 到 Rpn 的所有层都应该具有相同的权重。到目前为止我尝试实现的是:

    def create_base_network(input_shape, branch, input_im, img_input, roi_input):
def creat_conv_model(input_shape):
branch = Sequential()
branch.add(Conv2D(64,filter_size,subsample = strides, input_shape=input_shape , activation='relu',kernel_initializer='glorot_uniform' ,name='conv1_1_'+str(branch)))
branch.add(Conv2D(64,filter_size, subsample = strides, activation='relu', kernel_initializer='glorot_uniform',name='conv1_2_1'+str(branch)))
branch.add(MaxPooling2D(pool_size=(2,2), strides=pool_stride, name='pool1_'+str(branch)))
branch.add(Conv2D(128,filter_size,subsample = strides, activation='relu', kernel_initializer='glorot_uniform',name='conv2_1_'+str(branch)))

return branch
shared_layers = creat_conv_model(input_shape)
rpn_output = rpn(shared_layers(input_im),9,branch)
model = Model([img_input, roi_input], rpn_output[:2])

return model


Branch_left = create_base_network((64, 64, 3), 1, img_input_left, img_input, roi_input)
Branch_right = create_base_network((64, 64, 3), 2, img_input_right, img_input, roi_input)

当我运行此程序时,出现以下错误:

RuntimeError: Graph disconnected: cannot obtain value for tensor /input_2 at layer "input_2". The following previous layers were accessed without issue: []

有人可以帮忙吗?

最佳答案

要使模型共享权重,您只需创建一次。您不能创建两个模型。

shared_model = creat_conv_model((64, 64, 3), left)

如果rpn也是要共享的模型,则只能创建一次:

rpn_model = create_rpn(...)

然后传递输入:

img_neg_out = shared_model(img_input_left)
img_neg_out = rpn_model(img_neg_out)

img_pos_out = shared_model(img_input_right)
img_pos_out = rpn_model(img_pos_out)

关于创建模型 branch_leftbranch_right,这取决于您想要做什么以及如何训练。

关于python - 如何使用 Theano 在 Keras 中实现共享相同权重的 2 个相同的层分支,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47741562/

36 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com