gpt4 book ai didi

python - Tensorflow 和 cifar 10,测试单个图像

转载 作者:太空宇宙 更新时间:2023-11-03 12:03:22 26 4
gpt4 key购买 nike

我试图使用来自 tensorflow 的 cifar-10 预测单个图像的类别。

我找到了这段代码,但是失败并出现了这个错误:

赋值要求两个张量的形状匹配。 lhs 形状= [18,384] rhs 形状= [2304,384]我知道这是因为批处理的大小只有 1。(使用 expand_dims 我创建了一个假批处理。)

但我不知道如何解决这个问题?

我到处搜索,但没有解决方案..提前致谢!

from PIL import Image
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
width = 24
height = 24

categories = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck" ]

filename = "path/to/jpg" # absolute path to input image
im = Image.open(filename)
im.save(filename, format='JPEG', subsampling=0, quality=100)
input_img = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
tf_cast = tf.cast(input_img, tf.float32)
float_image = tf.image.resize_image_with_crop_or_pad(tf_cast, height, width)
images = tf.expand_dims(float_image, 0)
logits = cifar10.inference(images)
_, top_k_pred = tf.nn.top_k(logits, k=5)
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('/tmp/cifar10_train')
if ckpt and ckpt.model_checkpoint_path:
print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
else:
print('No checkpoint file found')
exit(0)
sess.run(init_op)
_, top_indices = sess.run([_, top_k_pred])
for key, value in enumerate(top_indices[0]):
print (categories[value] + ", " + str(_[0][key]))

编辑

我试图放置一个占位符,第一个形状为 None,但出现此错误:必须完全定义新变量 (local3/weights) 的形状,而不是 (?, 384)。

现在我真的迷路了..这是新代码:

from PIL import Image
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
import itertools
width = 24
height = 24

categories = [ "airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck" ]

filename = "toto.jpg" # absolute path to input image
im = Image.open(filename)
im.save(filename, format='JPEG', subsampling=0, quality=100)
x = tf.placeholder(tf.float32, [None, 24, 24, 3])
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
# Restore variables from training checkpoint.
input_img = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
tf_cast = tf.cast(input_img, tf.float32)
float_image = tf.image.resize_image_with_crop_or_pad(tf_cast, height, width)
images = tf.expand_dims(float_image, 0)
i = images.eval()
print (i)
sess.run(init_op, feed_dict={x: i})
logits = cifar10.inference(x)
_, top_k_pred = tf.nn.top_k(logits, k=5)
variable_averages = tf.train.ExponentialMovingAverage(
cifar10.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state('/tmp/cifar10_train')
if ckpt and ckpt.model_checkpoint_path:
print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
else:
print('No checkpoint file found')
exit(0)
_, top_indices = sess.run([_, top_k_pred])
for key, value in enumerate(top_indices[0]):
print (categories[value] + ", " + str(_[0][key]))

最佳答案

我认为这是因为通过 tf.Variabletf.get_variable 获取的变量必须具有完整定义的形状。您可以检查您的代码并给出完整定义的形状。

关于python - Tensorflow 和 cifar 10,测试单个图像,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40266275/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com