gpt4 book ai didi

python - 我对 Keras 中的batch_size理解正确吗?

转载 作者:行者123 更新时间:2023-11-30 08:57:16 24 4
gpt4 key购买 nike

我正在使用 Keras 的内置 inception_resnet_v2 来训练 CNN 来识别图像。训练模型时,我有一个 numpy 数据数组作为输入,输入形状为 (1000, 299, 299, 3),

 model.fit(x=X, y=Y, batch_size=16, ...) # Output shape `Y` is (1000, 6), for 6 classes

起初,当尝试预测时,我传入了形状为 (299, 299, 3) 的单个图像,但得到了错误

ValueError: Error when checking input: expected input_1 to have 4 dimensions, but got array with shape (299, 299, 3)

我用以下方法 reshape 了我的输入:

x = np.reshape(x, ((1, 299, 299, 3)))

现在,当我预测时,

y = model.predict(x, batch_size=1, verbose=0)

我没有收到错误。

我想确保我在训练和预测中正确理解batch_size。我的假设是:

1) 使用 model.fit,Keras 从输入数组中获取 batch_size 元素(在本例中,它一次处理我的 1000 个示例 16 个样本)

2) 使用 model.predict,我应该将输入 reshape 为单个 3D 数组,并且应该显式将 batch_size 设置为 1。

这些假设正确吗?

此外,向模型提供训练数据是否会更好(甚至可能),这样就不需要在预测之前进行这种 reshape ?感谢您帮助我学习这一点。

最佳答案

不,你的想法错了。 batch_size指定一次通过网络“转发”多少个数据示例(通常使用 GPU)。

默认情况下,该值设置为 32里面model.predict方法,但您可以以其他方式指定(就像您对 batch_size=1 所做的那样)。由于这个默认值,你得到了一个错误:

ValueError: Error when checking input: expected input_1 to have 4 dimensions, but got array with shape (299, 299, 3)

不应该以这种方式 reshape 您的输入,而是应该为其提供正确的批量大小。

假设,对于默认情况,您将传递形状为 (32, 299, 299, 3) 的数组,不同的类似batch_size ,例如与 batch_size=64此函数要求您传递形状 (64, 299, 299, 3 的输入.

编辑:

看来您需要将单个样本重新调整为批处理。我建议您使用np.expand_dims为了提高代码的可读性和可移植性,如下所示:

y = model.predict(np.expand_dims(x, axis=0), batch_size=1)

关于python - 我对 Keras 中的batch_size理解正确吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55212576/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com