gpt4 book ai didi

machine-learning - 使用预训练 ResNet50 网络的 OneClass SVM 模型

转载 作者:行者123 更新时间:2023-11-30 09:47:20 24 4
gpt4 key购买 nike

我正在尝试构建用于图像识别的 OneClass 分类器。我发现this文章,但因为我没有完整的源代码,所以我不太明白我在做什么。

X_train, X_test, y_train, y_test = train_test_split(x, y, random_state=42)

# X_train (2250, 200, 200, 3)
resnet_model = ResNet50(input_shape=(200, 200, 3), weights='imagenet', include_top=False)
features_array = resnet_model.predict(X_train)
# features_array (2250, 7, 7, 2048)
pca = PCA(svd_solver='randomized', n_components=450, whiten=True, random_state=42)
svc = SVC(kernel='rbf', class_weight='balanced')
model = make_pipeline(pca, svc)

param_grid = {'svc__C': [1, 5, 10, 50], 'svc__gamma': [0.0001, 0.0005, 0.001, 0.005]}
grid = GridSearchCV(model, param_grid)
grid.fit(X_train, y_train)

我有 2250 张图像(食物和非食物)200x200px 大小,我将此数据发送到 ResNet50 模型的预测方法。结果是 (2250, 7, 7, 2048) 张量,有人知道这个维度是什么意思吗?

当我尝试运行 grid.fit 方法时,出现错误:

ValueError: Found array with dim 4. Estimator expected <= 2.

最佳答案

这些是我可以得出的结论。

您将获得高于全局平均池化层的输出张量。 (请参阅 resnet_model.summary() 了解输入维度如何更改为输出维度)

要进行简单的修复,请在 resnet_model 之上添加平均池 2d 层。(这样输出形状就变成了(2250,1,1,2048))

resnet_model = ResNet50(input_shape=(200, 200, 3), weights='imagenet', include_top=False)
resnet_op = AveragePooling2D((7, 7), name='avg_pool_app')(resnet_model.output)
resnet_model = Model(resnet_model.input, resnet_op, name="ResNet")

这通常存在于 ResNet50 本身的源代码中。基本上,我们将 AveragePooling2D 层附加到 resnet50 模型中。最后一行将图层(第二行)和基线模型组合成模型对象。

现在输出维度 (feature_array) 将为 (2250, 1, 1, 2048)(因为添加了平均池化层)。

为了避免 ValueError,您应该将此 feature_array reshape 为 (2250, 2048)

feature_array = np.reshape(feature_array, (-1, 2048))

在问题中程序的最后一行,

grid.fit(X_train, y_train)

您已经适合 X_train(在本例中是图像)。这里正确的变量是features_array(这被认为是图像的摘要)。输入此行将纠正错误,

grid.fit(features_array, y_train)

要通过提取特征向量以这种方式进行更多微调,请查看 here (使用神经网络进行训练,而不是使用 PCA 和 SVM)。

希望这有帮助!!

关于machine-learning - 使用预训练 ResNet50 网络的 OneClass SVM 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50889313/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com