gpt4 book ai didi

machine-learning - 输入 channel 数与 Keras 中过滤器的相应维度不匹配

转载 作者:行者123 更新时间:2023-11-30 09:27:09 25 4
gpt4 key购买 nike

我正在使用keras基于Resnet50构建模型,代码如下

input_crop = Input(shape=(3, 224, 224))

# extract feature from image crop
resnet = ResNet50(include_top=False, weights='imagenet')
for layer in resnet.layers: # set resnet as non-trainable
layer.trainable = False

crop_encoded = resnet(input_crop)

但是,我遇到了错误

'ValueError: number of input channels does not match corresponding dimension of filter, 224 != 3'

我该如何解决这个问题?

最佳答案

由于 Theano 和 TensorFlow 使用不同的图像格式,经常会产生此类错误 backends对于喀拉斯。就您而言,图像显然采用 channels_first 格式 (Theano),而您很可能使用需要 channels_last 格式的 TensorFlow 后端。

MNIST CNN example Keras 提供了一种很好的方法来使您的代码免受此类问题的影响,即同时适用于 Theano 和 TensorFlow 后端 - 以下是针对您的数据的调整:

from keras import backend as K

img_rows, img_cols = 224, 224

if K.image_data_format() == 'channels_first':
input_crop = input_crop.reshape(input_crop.shape[0], 3, img_rows, img_cols)
input_shape = (3, img_rows, img_cols)
else:
input_crop = input_crop.reshape(input_crop.shape[0], img_rows, img_cols, 3)
input_shape = (img_rows, img_cols, 3)

input_crop = Input(shape=input_shape)

关于machine-learning - 输入 channel 数与 Keras 中过滤器的相应维度不匹配,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45909569/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com