gpt4 book ai didi

keras - 加载 vgg16_weights.h5 时如何使用 vgg-net?

转载 作者:行者123 更新时间:2023-12-05 00:19:08 28 4
gpt4 key购买 nike

我使用的是 Keras 中的 VGG-16 网络,详情见这里(detail)。

我的问题是如何用这个网络进行微调:我必须使用 224 × 224 的图像尺寸吗?使用这个网络时,必须使用 1000 个类别吗?
如果不使用 1000 个类别,就会出现以下错误:

Exception: Layer shape (4096L, 10L) not compatible with weight shape (4096, 1000).



求帮助,谢谢!

最佳答案

如果你想了解更多,我在这个 issue(in this issue)中贴出了详细的回答。以下代码段可以帮助你确定最后一层的尺寸:

from keras.models import Sequential, Graph
from keras.layers import Convolution2D, ZeroPadding2D, MaxPooling2D
from keras.layers import Flatten, Dense, Dropout
from keras.optimizers import SGD
import keras.backend as K

# Input image size. The convolutional weights do not depend on the spatial
# size, so it need not be 224x224. Each ZeroPadding2D((1, 1)) + 3x3 conv pair
# keeps the size unchanged. Each of the five 2x2/stride-2 max-poolings halves
# it, so 128x128 becomes 4x4 at the output of this stack.
img_width, img_height = 128, 128

# Build the convolutional part of VGG16 (the fully-connected head is added
# separately further down).
# The leading 3 is the channel axis. This assumes Keras is configured for
# channels-first ('th') ordering, i.e. (channels, rows, cols), where
# rows = height and cols = width.
first_layer = ZeroPadding2D((1, 1), input_shape=(3, img_height, img_width))

# Layer order matters: the weight loader copies group 'layer_k' of the
# weights file into model.layers[k]. Keep the padding/conv/pool layers in
# exactly this sequence.
model = Sequential()

# block 1: 2 x conv(64)
model.add(first_layer)
model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

# block 2: 2 x conv(128)
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

# block 3: 3 x conv(256)
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

# block 4: 3 x conv(512)
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

# block 5: 3 x conv(512)
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

# Map each layer's (unique) name to the layer object, so a layer such as
# 'conv5_1' can be looked up by name.
layer_dict = {layer.name: layer for layer in model.layers}

# load the weights

import h5py

weights_path = 'vgg16_weights.h5'

# Copy the pre-trained parameters into the convolutional layers.
# The file holds one group per layer ('layer_0', 'layer_1', ...), each with
# 'nb_params' datasets named 'param_0', 'param_1', .... Group k is assumed to
# correspond to model.layers[k].
# Only the first len(model.layers) groups are used. The remaining groups
# belong to the original fully-connected layers, whose last layer has
# 1000 outputs. Loading them into a head with a different class count causes
# the "(4096, 10) not compatible with (4096, 1000)" shape error.
with h5py.File(weights_path, 'r') as f:  # read-only; closed even if loading fails
    nb_layers_to_load = min(f.attrs['nb_layers'], len(model.layers))
    for k in range(nb_layers_to_load):
        g = f['layer_{}'.format(k)]
        # [()] reads each dataset fully into memory before the file is closed
        weights = [g['param_{}'.format(p)][()] for p in range(g.attrs['nb_params'])]
        model.layers[k].set_weights(weights)
print('Model loaded.')

# Here is what you want:
# Put a freshly initialised classifier on top of the pre-trained
# convolutional base. These Dense layers are created with our own output size
# instead of loading the saved 1000-class weights.
nb_classes = 10  # number of target classes for fine-tuning

graph_m = Graph()
# Reuse the conv base's own input shape (dropping the leading batch axis) so
# the graph input and the base model can never disagree.
graph_m.add_input('my_inp', input_shape=model.input_shape[1:])
graph_m.add_node(model, name='your_model', input='my_inp')
graph_m.add_node(Flatten(), name='Flatten', input='your_model')
graph_m.add_node(Dense(4096, activation='relu'), name='Dense1', input='Flatten')
graph_m.add_node(Dropout(0.5), name='Dropout1', input='Dense1')
graph_m.add_node(Dense(4096, activation='relu'), name='Dense2', input='Dropout1')
graph_m.add_node(Dropout(0.5), name='Dropout2', input='Dense2')
graph_m.add_node(Dense(nb_classes, activation='softmax'), name='Final', input='Dropout2')
graph_m.add_output(name='out1', input='Final')
# NOTE(review): lr=0.1 is large when fine-tuning pre-trained weights; a
# smaller learning rate is usually preferred. Confirm for your data.
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
graph_m.compile(optimizer=sgd, loss={'out1': 'categorical_crossentropy'})

请注意,你可以冻结特征提取层的训练,只微调最后的全连接层。
根据文档(doc),只需添加 trainable=False 即可冻结该层的训练。
冻结:

...
# trainable=False freezes this layer: its weights are not updated during training
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_1', trainable=False))
...

可训练:

...
# trainable=True (the default) lets this layer's weights be updated while fine-tuning
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_1', trainable=True))
...
trainable 默认为 True;如果你不了解这一特性,可能会遇到意料之外的结果……

关于keras - 加载 vgg16_weights.h5 时如何使用 vgg-net?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36554827/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com