gpt4 book ai didi

tensorflow - keras.model.predict 引发 ValueError : Error when checking input

转载 作者:行者123 更新时间:2023-11-30 09:43:35 28 4
gpt4 key购买 nike

我在 MNIST 数据集上训练了基本的神经网络模型。这是训练的代码:(省略导入)

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data(path='mnist.npz')
x_train, x_test = x_train/255.0, x_test/255.0

#1st Define the model
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape = (28,28)), #input layer
tf.keras.layers.Dense(512, activation=tf.nn.relu), #main computation layer
tf.keras.layers.Dropout(0.2), #Dropout layer to avoid overfitting
tf.keras.layers.Dense(10, activation=tf.nn.softmax) #output layer / Softmax is a classifier AF
])

#2nd Compile the model
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

#3rd Fit the model
model.fit(x_train, y_train, epochs=5)

#4th Save the model
model.save('models/mnistCNN.h5')

#5th Evaluate the model
model.evaluate(x_test, y_test)

我想看看这个模型如何与我自己的输入一起工作,所以我在this post的帮助下编写了一个预测脚本。 。我的预测代码是:(省略导入)

model = load_model('models/mnistCNN.h5')

for i in range(3):
img = Image.open(str(i+1) + '.png').convert("L")
img = img.resize((28,28))
im2arr = np.array(img)
im2arr = im2arr/255
im2arr = im2arr.reshape(1, 28, 28, 1)
y_pred = model.predict(im2arr)
print('For Image',i+1,'Prediction = ',y_pred)

首先,我不明白这一行的目的:

im2arr = im2arr.reshape(1, 28, 28, 1)

如果有人能够阐明为什么这条线是必要的,那将会有很大的帮助。

其次,这一行抛出以下错误:

ValueError: Error when checking input: expected flatten_input to have 3 dimensions, but got array with shape (1, 28, 28, 1)

我在这里缺少什么?

最佳答案

第一个维度用于批量大小。它是由 keras.model 内部添加的。所以这一行只是将其添加到图像数组中。

im2arr = im2arr.reshape(1, 28, 28, 1)

您得到的错误是因为您用于训练的 mnist 数据集 中的单个示例的形状为 (28, 28),输入层也是如此。要消除此错误,您需要将此行更改为

im2arr = img.reshape((1, 28, 28))

关于tensorflow - keras.model.predict 引发 ValueError : Error when checking input,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55679218/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com