gpt4 book ai didi

python - 使用张量输入时 Keras 模型预测会发生变化

转载 作者:行者123 更新时间:2023-12-01 03:21:50 24 4
gpt4 key购买 nike

我想使用来自 Keras 的预训练 Inception-V3 模型,与来自 Tensorflow 的输入管道配对(即通过张量提供网络的输入输入)。
这是我的代码:

import tensorflow as tf
from keras.preprocessing.image import load_img, img_to_array
from keras.applications.inception_v3 import InceptionV3, decode_predictions, preprocess_input
import numpy as np

img_sample_filename = 'my_image.jpg'
img = img_to_array(load_img(img_sample_filename, target_size=(299,299)))
img = preprocess_input(img)
img_tensor = tf.constant(img[None,:])

# WITH KERAS:
model = InceptionV3()
pred = model.predict(img[None,:])
pred = decode_predictions(np.asarray(pred)) #<------ correct prediction!
print(pred)

# WITH TF:
model = InceptionV3(input_tensor=img_tensor)
init = tf.global_variables_initializer()

with tf.Session() as sess:
from keras import backend as K
K.set_session(sess)

sess.run(init)
pred = sess.run([model.output], feed_dict={K.learning_phase(): 0})

pred = decode_predictions(np.asarray(pred)[0])
print(pred) #<------ wrong prediction!

哪里 my_image.jpg是我想要分类的任何图像。

如果我使用 keras 的 predict函数来计算预测,结果是正确的。但是,如果我从图像数组中创建一个张量并通过 input_tensor=... 将该张量提供给模型然后通过 sess.run([model.output], ...) 计算预测结果大错特错。

不同行为的原因是什么?我不能这样使用Keras网络吗?

最佳答案

最后,挖掘 InceptionV3代码,我发现了问题:sess.run(init)覆盖 InceptionV3 中加载的权重的构造函数。
我发现这个问题的 -dirty- 修复是在 sess.run(init) 之后重新加载权重.

from keras.applications.inception_v3 import get_file, WEIGHTS_PATH

with tf.Session() as sess:
from keras import backend as K
K.set_session(sess)

sess.run(init)
weights_path = get_file(
'inception_v3_weights_tf_dim_ordering_tf_kernels.h5',
WEIGHTS_PATH,
cache_subdir='models',
md5_hash='9a0d58056eeedaa3f26cb7ebd46da564')
model.load_weights(weights_path)
pred = sess.run([model.output], feed_dict={K.learning_phase(): 0})

备注 : get_file()的参数直接取自 InceptionV3的构造函数,在我的示例中,专门用于使用 image_data_format='channels_last' 恢复整个网络的权重。 .
我问了 this Github issue如果有更好的解决方法。如果我应该获得更多信息,我会更新这个答案。

关于python - 使用张量输入时 Keras 模型预测会发生变化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44627401/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com