python - Reducing the memory requirements of a convolutional neural network

Reposted · Author: 行者123 · Updated: 2023-11-30 09:26:51
I am trying to build a convolutional neural network for images. I currently have about 136 images across 17 categories (more will be added later).

Each image is a numpy.array of shape (330, 330, 3).

I am using the following network code:

from keras.models import Sequential
from keras.layers import (Activation, Convolution2D, Dense, Dropout,
                          Flatten, MaxPooling2D)

batch_size = 64
nb_classes = 17
nb_epoch = 2
img_rows = 330
img_cols = 330
nb_filters = 16
nb_conv = 3  # convolution kernel size
nb_pool = 2

model = Sequential()

# 1st conv layer:
model.add(Convolution2D(
    nb_filters, (nb_conv, nb_conv),
    padding="valid",
    input_shape=(img_rows, img_cols, 3),
    data_format='channels_last'))
model.add(Activation('relu'))

# 2nd conv layer:
model.add(Convolution2D(nb_filters, (nb_conv, nb_conv), data_format='channels_last'))
model.add(Activation('relu'))

# max-pooling layer:
model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool), data_format="channels_last"))
model.add(Dropout(0.25))

# 2 FC layers:
model.add(Flatten())
model.add(Dense(128))
model.add(Activation('relu'))

model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adadelta')

model.summary()

model.fit(X_train, y_train, batch_size=batch_size, epochs=nb_epoch, verbose=1)

However, shortly after the first epoch starts, it prints a message that more than 10% of system memory has been used. The machine then becomes unresponsive and I have to hard-reboot it.

What steps can I take, or what changes can I make to the code, to reduce the memory requirements?

Best Answer

You can find out what is causing the problem (i.e. which layer has too many parameters) by looking at the output of model.summary():

Layer (type)                 Output Shape              Param #
=================================================================
conv2d_189 (Conv2D)          (None, 328, 328, 16)      448
_________________________________________________________________
activation_189 (Activation)  (None, 328, 328, 16)      0
_________________________________________________________________
conv2d_190 (Conv2D)          (None, 326, 326, 16)      2320
_________________________________________________________________
activation_190 (Activation)  (None, 326, 326, 16)      0
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 163, 163, 16)      0
_________________________________________________________________
dropout_3 (Dropout)          (None, 163, 163, 16)      0
_________________________________________________________________
flatten_1 (Flatten)          (None, 425104)            0
_________________________________________________________________
dense_5 (Dense)              (None, 128)               54413440
_________________________________________________________________
activation_191 (Activation)  (None, 128)               0
_________________________________________________________________
dropout_4 (Dropout)          (None, 128)               0
_________________________________________________________________
dense_6 (Dense)              (None, 17)                2193
_________________________________________________________________
activation_192 (Activation)  (None, 17)                0
=================================================================
Total params: 54,418,401
Trainable params: 54,418,401
Non-trainable params: 0
_________________________________________________________________

As you can see, the Dense layer has far too many parameters because the output of the Flatten layer is so large: 425104 * 128 + 128 = 54,413,440, i.e. 54 million parameters in this single layer (almost 99% of all the parameters in the model). So how do we reduce this number? You need to shrink the output of the convolutional layers, either with the stride argument (which I don't recommend) or with pooling layers (preferably one after each convolutional layer). Let's add two more pooling layers and one more conv layer (I even increased the number of filters in the deeper conv layers, since that is usually a good idea as the network gets deeper):
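The parameter count quoted above can be verified with a couple of lines of arithmetic; a minimal sketch (the shapes are taken from the model summary):

```python
# Output of the max-pooling layer is (163, 163, 16); flattening it gives
# the input size of the first Dense layer.
flatten_size = 163 * 163 * 16
dense_units = 128

# A Dense layer has one weight per (input, unit) pair plus one bias per unit.
dense_params = flatten_size * dense_units + dense_units

print(flatten_size)   # 425104
print(dense_params)   # 54413440
```

This confirms that almost the entire 54 million parameter budget sits in that one fully connected layer, so shrinking the Flatten output is the highest-leverage fix.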

# 1st conv + pooling layer:
model.add(Convolution2D(
    nb_filters, (nb_conv, nb_conv),
    padding="valid",
    input_shape=(img_rows, img_cols, 3),
    data_format='channels_last'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool), data_format="channels_last"))

# 2nd conv + pooling layer:
model.add(Convolution2D(nb_filters*2, (nb_conv, nb_conv), data_format='channels_last'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool), data_format="channels_last"))

# 3rd conv + pooling layer:
model.add(Convolution2D(nb_filters*2, (nb_conv, nb_conv), data_format='channels_last'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool), data_format="channels_last"))

# the rest is the same...

The model summary output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_197 (Conv2D)          (None, 328, 328, 16)      448
_________________________________________________________________
activation_203 (Activation)  (None, 328, 328, 16)      0
_________________________________________________________________
max_pooling2d_16 (MaxPooling (None, 164, 164, 16)      0
_________________________________________________________________
conv2d_198 (Conv2D)          (None, 162, 162, 32)      4640
_________________________________________________________________
activation_204 (Activation)  (None, 162, 162, 32)      0
_________________________________________________________________
max_pooling2d_17 (MaxPooling (None, 81, 81, 32)        0
_________________________________________________________________
conv2d_199 (Conv2D)          (None, 79, 79, 32)        9248
_________________________________________________________________
activation_205 (Activation)  (None, 79, 79, 32)        0
_________________________________________________________________
max_pooling2d_18 (MaxPooling (None, 39, 39, 32)        0
_________________________________________________________________
dropout_9 (Dropout)          (None, 39, 39, 32)        0
_________________________________________________________________
flatten_4 (Flatten)          (None, 48672)             0
_________________________________________________________________
dense_11 (Dense)             (None, 128)               6230144
_________________________________________________________________
activation_206 (Activation)  (None, 128)               0
_________________________________________________________________
dropout_10 (Dropout)         (None, 128)               0
_________________________________________________________________
dense_12 (Dense)             (None, 17)                2193
_________________________________________________________________
activation_207 (Activation)  (None, 17)                0
=================================================================
Total params: 6,246,673
Trainable params: 6,246,673
Non-trainable params: 0
_________________________________________________________________

As you can see, the model now has fewer than 6.5 million parameters, almost one ninth of the previous model's count. You could even add another pooling layer to reduce the number further. Bear in mind, however, that as your model gets deeper (i.e. gains more and more layers) you may need to deal with problems such as vanishing gradients and overfitting.
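The output shapes in the summary above follow from two simple rules, which makes it easy to predict the effect of adding yet another pooling layer. A minimal sketch of the arithmetic (a 3x3 "valid" convolution removes one pixel of border on each side; a 2x2 max-pool halves the size, rounding down):

```python
def conv_out(size, kernel=3):
    # "valid" padding: the kernel must fit entirely inside the input
    return size - kernel + 1

def pool_out(size, pool=2):
    # non-overlapping pooling: stride equals the pool size, floor division
    return size // pool

size = 330
for _ in range(3):            # three conv + pool blocks, as in the revised model
    size = pool_out(conv_out(size))

flatten_size = size * size * 32   # 32 filters in the last conv layer
print(size, flatten_size)         # 39 48672
```

The result, 39 x 39 x 32 = 48672, matches the Flatten layer in the summary, and a fourth conv + pool block would shrink it by roughly another factor of four before it reaches the Dense layer.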

Regarding "python - Reducing the memory requirements of a convolutional neural network", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/52463538/
