gpt4 book ai didi

python - 获取keras中所有已知vgg-16类的列表

转载 作者:太空狗 更新时间:2023-10-30 00:05:01 25 4
gpt4 key购买 nike

我使用来自 Keras 的预训练 VGG-16 模型。

到目前为止我的工作源代码是这样的:

from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions

model = VGG16()

print(model.summary())

image = load_img('./pictures/door.jpg', target_size=(224, 224))
image = img_to_array(image) #output Numpy-array

image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))

image = preprocess_input(image)
yhat = model.predict(image)

label = decode_predictions(yhat)
label = label[0][0]

print('%s (%.2f%%)' % (label[1], label[2]*100))

我发现该模型接受了 1000 个类的训练。是否有可能获得该模型训练的类(class)列表?打印出所有预测标签不是一种选择,因为只返回了 5 个。

提前致谢

最佳答案

您可以使用 decode_predictions 并在 top=1000 参数中传递类别总数(只有它的默认值为 5)。

或者您可以看看 Keras 如何在内部执行此操作:它下载文件 imagenet_class_index.json(通常将其缓存在 ~/.keras/models/ 中)。这是一个包含所有类标签的简单 json 文件。

关于python - 获取keras中所有已知vgg-16类的列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47474869/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com