gpt4 book ai didi

python - 无法弄清楚如何在 Keras 的 Conv2D 层中为我自己的数据集定义 input_shape

转载 作者:行者123 更新时间:2023-11-28 22:19:17 25 4
gpt4 key购买 nike

长话短说

我在定义输入形状时遇到这些错误

ValueError: Error when checking input: expected conv2d_1_input to have 4 dimensions, but got array with shape (4000, 20, 20)

ValueError: Input 0 is incompatible with layer conv2d_1: expected ndim=4, found ndim=5

长显式版本:

我正在使用不同的 Keras NN 尝试对我自己的数据集进行分类。

到目前为止,我的 ANN 成功了,但我的 CNN 遇到了问题。

数据集

Complete Code

数据集由指定大小并用 0 填充的矩阵组成,矩阵包含指定大小并用 1 填充的子矩阵。子矩阵是可选的,目标是训练神经网络预测矩阵是否包含子矩阵。为了使其更难检测,我在矩阵中添加了各种类型的噪声。

这是一张单独矩阵的图片,黑色部分是 0,白色部分是 1。图像的像素与矩阵中的条目之间存在 1:1 的对应关系。

enter image description here

我使用 numpy savetxt 和 loadtxt 将它们保存在文本中。然后看起来像这样:

#________________Array__Info:__(4000, 20, 20)__________
#________________Entry__Number__1________
0 0 1 1 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0
0 0 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0 1 0 1
0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0
0 0 0 1 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 1 0 0 1
0 0 1 1 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 1 0 0 0 0
0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1
0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 1 1 1 0
0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 1
0 0 0 1 0 0 0 1 0 1 0 0 0 0 0 1 0 0 0 0
0 0 0 0 0 0 0 0 0 1 1 0 0 1 0 0 0 1 1 1
0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 1
0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0
0 0 1 1 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 0
0 1 0 1 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 1
1 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0
#________________Entry__Number__2________
0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0
1 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 1
1 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 1 1 0 0
0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0
0 0 1 0 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 1
0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0
1 0 1 0 0 1 0 1 0 1 0 0 0 0 1 1 1 0 0 1
0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0
1 0 0 0 1 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0
0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 1
0 0 0 0 0 1 1 0 0 0 0 1 0 1 0 0 0 0 0 0
0 0 1 1 0 0 0 0 0 0 0 1 1 1 1 1 0 1 0 0
0 0 0 0 0 0 0 1 1 0 1 1 1 1 1 1 0 0 0 1
0 1 0 0 0 0. . . . . . (and so on)

Complete Dataset

CNN代码

Github

代码:(忽略导入)

# data

inputData = dsg.loadDataset("test_input.txt")
outputData = dsg.loadDataset("test_output.txt")
print("the size of the dataset is: ", inputData.shape, " of type: ", type(inputData))


# parameters

# CNN

cnn = Sequential()

cnn.add(Conv2D(32, (3, 3), input_shape = inputData.shape, activation = 'relu'))

cnn.add(MaxPooling2D(pool_size = (2, 2)))

cnn.add(Flatten())

cnn.add(Dense(units=64, activation='relu'))

cnn.add(Dense(units=1, activation='sigmoid'))

cnn.compile(optimizer = "adam", loss = 'binary_crossentropy', metrics = ['accuracy'])

cnn.summary()

cnn.fit(inputData,
outputData,
epochs=100,
validation_split=0.2)

问题:

我得到这个输出错误信息

Using TensorFlow backend.
the size of the dataset is: (4000, 20, 20) of type: <class 'numpy.ndarray'>
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 3998, 18, 32) 5792
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 1999, 9, 32) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 575712) 0
_________________________________________________________________
dense_1 (Dense) (None, 64) 36845632
_________________________________________________________________
dense_2 (Dense) (None, 1) 65
=================================================================
Total params: 36,851,489
Trainable params: 36,851,489
Non-trainable params: 0
_________________________________________________________________
Traceback (most recent call last):
File "D:\GOOGLE DRIVE\School\sem-2-2018\BSP2\BiCS-BSP-2\CNN\matrixCNN.py", line 47, in <module>
validation_split=0.2)
File "C:\Code\Python\lib\site-packages\keras\models.py", line 963, in fit
validation_steps=validation_steps)
File "C:\Code\Python\lib\site-packages\keras\engine\training.py", line 1637, in fit
batch_size=batch_size)
File "C:\Code\Python\lib\site-packages\keras\engine\training.py", line 1483, in _standardize_user_data
exception_prefix='input')
File "C:\Code\Python\lib\site-packages\keras\engine\training.py", line 113, in _standardize_input_data
'with shape ' + str(data_shape))
ValueError: Error when checking input: expected conv2d_1_input to have 4 dimensions, but got array with shape (4000, 20, 20)

我真的不知道怎么解决这个问题。我查看了 documentation of Conv2D,它说要把它放在这样的形式中:(批处理、高度、宽度、 channel )。就我而言(我认为):

input_shape=(4000, 20, 20, 1)

,因为我有 4000 个只有 1 和 0 的 20*20 矩阵

但随后我收到此错误消息:

Using TensorFlow backend.
the size of the dataset is: (4000, 20, 20) of type: <class 'numpy.ndarray'>
Traceback (most recent call last):
File "D:\GOOGLE DRIVE\School\sem-2-2018\BSP2\BiCS-BSP-2\CNN\matrixCNN.py", line 30, in <module>
cnn.add(Conv2D(32, (3, 3), input_shape = (4000, 12, 12, 1), activation = 'relu'))
File "C:\Code\Python\lib\site-packages\keras\models.py", line 467, in add
layer(x)
File "C:\Code\Python\lib\site-packages\keras\engine\topology.py", line 573, in __call__
self.assert_input_compatibility(inputs)
File "C:\Code\Python\lib\site-packages\keras\engine\topology.py", line 472, in assert_input_compatibility
str(K.ndim(x)))
ValueError: Input 0 is incompatible with layer conv2d_1: expected ndim=4, found ndim=5

我应该以哪种确切的形状将数据传递到 CNN?

所有文件都可用 here感谢您的宝贵时间。

最佳答案

您的 CNN 期望形状为 (num_samples, 20, 20, 1) ,而您的数据格式为 (num_samples, 20, 20) .

由于您只有 1 个 channel ,您可以将数据 reshape 为 (4000, 20, 20, 1)

inputData = inputData.reshape(-1, 20, 20, 1)

如果你想在模型内部进行 reshape ,你可以添加一个 Reshape层。作为您的第一层:

model.add(Reshape(input_shape = (20, 20), target_shape=(20, 20, 1)))

关于python - 无法弄清楚如何在 Keras 的 Conv2D 层中为我自己的数据集定义 input_shape,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49843113/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com