gpt4 book ai didi

python - 在 Keras 上正确构建模型

转载 作者:行者123 更新时间:2023-11-30 09:32:06 25 4
gpt4 key购买 nike

我是神经网络和 Keras 的新手,我想构建一个 CNN 来预测图像的某些值。 (这三个值预测图像顶部模糊的大小、长度和宽度)。所有 3 个值的范围都是从 0 到 1,而且我有一个很大的数据集。

不过,我不太确定如何构建 CNN 来做到这一点,因为到目前为止我构建的所有原型(prototype)代码都给出了格式 [1.,0.,0.] 的预测 而不是每个值的 0 到 1 之间的范围。最重要的是,尽管改变了 SGD 优化器中的轮数和衰减值,但我的损失函数根本没有任何变化。你能告诉我哪里出错了吗?这是我到目前为止所拥有的:

images, labels = load_dataset("images")   # function that loads images
images = np.asarray(images) # images are flattened 424*424 arrays (grayscale)
labels = np.asarray(labels) # Lables are 3-arrays, each value is float from 0-1

# I won't write this part but here I split into train_imgs and test_imgs

model = keras.Sequential()
# explicitly define SGD so that I can change the decay rate
sgd = keras.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

model.add(keras.layers.Dense(32, input_shape=(424*424,) ))
model.add(keras.layers.Activation('relu'))
model.add(keras.layers.Dense(3, activation='softmax'))

model.compile(loss='mean_squared_error',optimizer=sgd)
# note: I also tried defining a weighted binary crossentropy but it changed nothing

checkpoint_name = 'Weights-{epoch:03d}--{val_loss:.5f}.hdf5'
checkpoint = ModelCheckpoint(checkpoint_name, monitor='val_loss', verbose = 0, save_best_only = True, mode ='auto')
callbacks_list = [checkpoint]

model.fit(train_imgs, train_labls, epochs=20, batch_size=32, validation_split = 0.2, callbacks=callbacks_list)

predictions = model.predict(test_imgs) # make predictions on same test set!

现在我知道我遗漏了 dropout 层,但我希望 CNN 过度拟合我的数据,此时我只想让它做任何事情。当我对同一组图像进行预测时,我希望得到准确的结果,不是吗?我不太确定我错过了什么。谢谢您的帮助!

最佳答案

首先,替换 'softmax''sigmoid' .

Sigmoid 将使三个输出的范围从 0 到 1。另请注意,softmax 是用于分类的。它尝试仅最大化三个值之一,并且三个值之和始终为 1。

第二,如果你的损失被完全卡住,问题可能出在'relu' (relu 有一个没有梯度的恒定零区域)。您可以替换 'relu ' 与另一个东西,例如 'sigmoid''tanh' ,或者您也可以添加 BatchNormalization() relu 之前的层。

作为初学者的选择,我总是更喜欢使用 optimizer='adam' ,这通常比 SGD 快得多,并且您不需要太关心学习率(当然高级模型和最佳结果可能需要调整)。

关于python - 在 Keras 上正确构建模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53749895/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com