gpt4 book ai didi

python - 为什么 Python 中的 CNN 比 Matlab 中的表现差很多?

转载 作者:行者123 更新时间:2023-12-04 15:33:38 25 4
gpt4 key购买 nike

我在 Matlab 2019b 中训练了一个 CNN,它可以进行二元分类。当这个 CNN 在测试数据集中进行测试时,它获得了大约 95% 的准确率。我用了exportONNXNetwork函数,以便我可以在 Tensorflow、Keras 中实现我的 CNN。这是我用来在 keras 中使用 ONNX 文件的代码:

import onnx
from onnx_tf.backend import prepare
import numpy as np
from numpy import array
from IPython.display import display
from PIL import Image

onnx_model = onnx.load("model.onnx")
tf_rep = prepare(onnx_model)
img = Image.open("image.jpg").resize((224,224))
img = array(img).reshape(1,3,224,224)
img = img.astype(np.uint8)

classification = tf_rep.run(img)
print(classification)

当这段 Python 代码在相同 测试数据集上进行测试时,它几乎将所有东西都归类为 0 类,只有少数情况属于 1 .我不确定为什么会这样。

最佳答案

乍一看,我认为您需要排列图像轴而不是 reshape :

img = Image.open("image.jpg").resize((224,224))
img = array(img).transpose(2, 0, 1)
img = np.expand_dims(img, 0)

您从 PIL 获得的图像采用 channel 最后格式,即形状为 (height, width, channels) 的张量,在本例中为 (224, 224, 3)。您的模型需要 channel 优先格式的输入,即形状为 (channels, height, width) 的张量,在本例中为 (3, 224, 224)

您需要将最后一个轴移到前面。如果你使用 reshape,NumPy 将以 C 顺序遍历数组(最后一个轴索引变化最快),这意味着你的图像最终会被打乱。这在示例中更容易理解:

>>> img = np.arange(48).reshape(4, 4, 3)
>>> img[0, 0, :]
array([0, 1, 2])

(0, 0)像素的RGB值为(0, 1, 2)。如果您使用 np.transpose(),则会保留:

>>> img.transpose(2, 0, 1)[:, 0, 0]
array([0, 1, 2])

如果你使用 reshape,你的图像会被打乱:

>>> img.reshape(3, 224, 224)[:, 0, 0]
array([0, 16, 32])

关于python - 为什么 Python 中的 CNN 比 Matlab 中的表现差很多?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60517818/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com