gpt4 book ai didi

tensorflow - CNN 的模型架构设计

转载 作者:行者123 更新时间:2023-11-30 09:58:24 27 4
gpt4 key购买 nike

我想为有 300 个类的数据集设计 CNN。我已经用以下模型对两个类(class)进行了测试。它具有良好的准确性。

model = Sequential([
Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH ,3)),
MaxPooling2D(),
Conv2D(32, 3, padding='same', activation='relu'),
MaxPooling2D(),
Conv2D(64, 3, padding='same', activation='relu'),
MaxPooling2D(),
Flatten(),
Dense(512, activation='relu'),
Dense(1, activation='sigmoid')
])

但是我将类别数量增加为 5,准确率下降到 0.2 左右。如何设计 300 个类别的 CNN 架构?

最佳答案

为了对超过 2 个类别的数据集进行训练,您需要对最后一个 Dense 层使用 categorical_crossentropy 损失和 softmax 激活层。

最后一个 Dense 的神经元数量将决定您想要预测的类别数量,因此如果您有 300 个类别,则结果将如下所示:

[...]
MaxPooling2D(),
Flatten(),
Dense(512, activation='relu'),
Dense(300, activation='softmax')
])

此外,类越多,模型学习每个类的特征就越困难。因此,您需要增加其宽度(更多过滤器)、深度(更多卷积 block )或分辨率(更大的输入)。

设计 CNN 模型的一个好的开始是查看 VGG 架构及其卷积 block 。

希望它能帮助您获得一些直觉。

关于tensorflow - CNN 的模型架构设计,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60077092/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com