
python - TensorFlow: Accuracy drops drastically after freezing the graph?

Reposted. Author: 行者123. Updated: 2023-11-28 17:17:42

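Is it common for accuracy to drop drastically after freezing a graph for serving? While training and evaluating a pretrained inception-resnet-v2 on the flowers dataset, I get 98-99% accuracy, and correct predictions come with 90+% probability. However, after freezing my graph and predicting again, the model is much less accurate: the correct label is predicted with only 30-40% confidence.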

After training, I have these files:
  1. a checkpoint file
  2. a model.ckpt.index file
  3. a model.ckpt.meta file
  4. a model.ckpt file
  5. a graph.pbtxt file.
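With files like these, the freeze script below expects `--input_checkpoint` to be the checkpoint prefix (here `model.ckpt`), since it appends `.meta` itself. A minimal stdlib sketch for locating that prefix in a folder; `find_checkpoint_prefix` is a hypothetical helper, and it simply assumes the most recently modified `.meta` file is the one you want.

```python
import glob
import os


def find_checkpoint_prefix(model_folder):
    """Return the checkpoint prefix (path without '.meta') of the newest
    .meta file in model_folder, or None if there is none.

    Hypothetical helper: it only looks at file names and modification
    times and does not validate the checkpoint contents.
    """
    metas = glob.glob(os.path.join(model_folder, "*.meta"))
    if not metas:
        return None
    newest = max(metas, key=os.path.getmtime)
    return newest[: -len(".meta")]
```

For the files listed above, `find_checkpoint_prefix("train_dir")` would return `train_dir/model.ckpt` (the folder name is only an example).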

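Because I couldn't run the official freeze_graph script in the tensorflow repository on GitHub (I suspect this is because training left me with a pbtxt file rather than a pb file), I am reusing the code from this tutorial instead.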


import os, argparse

import tensorflow as tf
from tensorflow.python.framework import graph_util

dir = os.path.dirname(os.path.realpath(__file__))

def freeze_graph(model_folder, input_checkpoint):
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    # input_checkpoint = checkpoint.model_checkpoint_path

    # We specify the full file name of our frozen graph
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/frozen_model.pb"

    # Before exporting our graph, we need to specify our output node
    # This is how TF decides which part of the graph it has to keep and which part it can drop
    # NOTE: this variable is plural, because you can have multiple output nodes
    output_node_names = "InceptionResnetV2/Logits/Predictions"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # The output node names are used to select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_folder", type=str, help="Model folder to export")
    parser.add_argument("--input_checkpoint", type=str, help="Input checkpoint name")
    args = parser.parse_args()

    freeze_graph(args.model_folder, args.input_checkpoint)
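One pitfall in the script above: `"/".join(input_checkpoint.split('/')[:-1])` yields an empty string when `--input_checkpoint` has no directory part (e.g. `model.ckpt`), so the frozen graph would be written to `/frozen_model.pb` at the filesystem root. A sketch of the same path logic with `os.path`, which handles both cases; `frozen_graph_path` and `tutorial_output_path` are hypothetical names used only for illustration.

```python
import os


def tutorial_output_path(input_checkpoint):
    # Same expression as in the tutorial script, kept for comparison
    return "/".join(input_checkpoint.split('/')[:-1]) + "/frozen_model.pb"


def frozen_graph_path(input_checkpoint, filename="frozen_model.pb"):
    """Place the frozen graph next to the checkpoint files.

    os.path.dirname returns '' for a bare prefix such as 'model.ckpt', and
    os.path.join('', filename) is just filename (the current directory),
    not a path at the filesystem root.
    """
    return os.path.join(os.path.dirname(input_checkpoint), filename)
```

Swapping the two `output_graph` lines for `output_graph = frozen_graph_path(input_checkpoint)` keeps the behaviour for paths that contain a directory.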

Then I load the frozen graph and predict with the following code:

from __future__ import print_function

import tensorflow as tf
from scipy.misc import imread, imresize
import numpy as np

img = imread("./dandelion.jpg")
img = imresize(img, (299, 299, 3))
img = img.astype(np.float32)
img = np.expand_dims(img, 0)

labels_dict = {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}

# Define the filename of the frozen graph
graph_filename = "./frozen_model.pb"

# Create a GraphDef object and parse the serialized graph into it
with tf.gfile.GFile(graph_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Construct a new graph and import the GraphDef into it
# (the default import name is "import", hence the "import/" prefix below)
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)

    # We define the input and output nodes we will feed and fetch
    input_node = graph.get_tensor_by_name('import/batch:0')
    output_node = graph.get_tensor_by_name('import/InceptionResnetV2/Logits/Predictions:0')

    with tf.Session() as sess:
        predictions = sess.run(output_node, feed_dict={input_node: img})
        print(predictions)
        label_predicted = np.argmax(predictions[0])

print('Predicted Flower:', labels_dict[label_predicted])
print('Prediction probability:', predictions[0][label_predicted])

Output:

2017-04-11 17:38:21.722217: I tensorflow/stream_executor/cuda/] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2017-04-11 17:38:21.722608: I tensorflow/core/common_runtime/gpu/] Found device 0 with properties:
name: GeForce GTX 860M
major: 5 minor: 0 memoryClockRate (GHz) 1.0195
pciBusID 0000:01:00.0
Total memory: 3.95GiB
Free memory: 3.42GiB
2017-04-11 17:38:21.722624: I tensorflow/core/common_runtime/gpu/] DMA: 0
2017-04-11 17:38:21.722630: I tensorflow/core/common_runtime/gpu/] 0: Y
2017-04-11 17:38:21.722642: I tensorflow/core/common_runtime/gpu/] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 860M, pci bus id: 0000:01:00.0)
2017-04-11 17:38:22.183204: I tensorflow/compiler/xla/service/] platform CUDA present with 1 visible devices
2017-04-11 17:38:22.183232: I tensorflow/compiler/xla/service/] platform Host present with 8 visible devices
2017-04-11 17:38:22.184007: I tensorflow/compiler/xla/service/] XLA service 0xb85a1c0 executing computations on platform Host. Devices:
2017-04-11 17:38:22.184022: I tensorflow/compiler/xla/service/] StreamExecutor device (0): <undefined>, <undefined>
2017-04-11 17:38:22.184140: I tensorflow/compiler/xla/service/] platform CUDA present with 1 visible devices
2017-04-11 17:38:22.184149: I tensorflow/compiler/xla/service/] platform Host present with 8 visible devices
2017-04-11 17:38:22.184610: I tensorflow/compiler/xla/service/] XLA service 0xb631ee0 executing computations on platform CUDA. Devices:
2017-04-11 17:38:22.184620: I tensorflow/compiler/xla/service/] StreamExecutor device (0): GeForce GTX 860M, Compute Capability 5.0
[[ 0.1670652 0.46482906 0.12899996 0.12481128 0.11429448]]
Predicted Flower: dandelion
Prediction probability: 0.464829
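The printed row is the softmax output over the five classes (its entries sum to about 1), so `np.argmax` picks index 1, `dandelion`, with probability of about 0.46. A small NumPy sketch for ranking every class from such a row, reusing the question's `labels_dict` and the probabilities copied from the output above; `top_k_labels` is a hypothetical helper.

```python
import numpy as np

labels_dict = {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}

# Probabilities copied from the output printed above (batch of one image)
predictions = np.array([[0.1670652, 0.46482906, 0.12899996, 0.12481128, 0.11429448]])


def top_k_labels(probs, labels, k=3):
    """Return the k most likely (label, probability) pairs for one image."""
    order = np.argsort(probs)[::-1][:k]
    return [(labels[int(i)], float(probs[i])) for i in order]


for label, p in top_k_labels(predictions[0], labels_dict):
    print("%-10s %.4f" % (label, p))
```

A well-calibrated confident prediction would concentrate most of the mass on one class; here more than half of it is spread over the other four.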

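Potential source of the problem: I first trained my model with TF 0.12, but I believe it is compatible with TF 1.01, the version I am using now. As a precaution, I upgraded my files to TF 1.01 and retrained the model to get a new set of checkpoint files (with the same accuracy), then froze the graph from those checkpoint files. I compiled TensorFlow from source. Is the problem caused by my using a pbtxt file instead of a pb file? I don't know how to get a pb file out of training my model.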


Answer:

I recommend using the default preprocessing function for InceptionResnet V2.
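The likely mismatch, assuming training used the standard Inception evaluation preprocessing that the `preprocess` function in this answer reproduces: the question's script feeds raw pixels in [0, 255] (merely cast to float32), while that preprocessing maps pixels to [-1, 1]. A minimal NumPy sketch of just the value scaling, leaving out the central crop; `scale_to_inception_range` is a hypothetical name.

```python
import numpy as np


def scale_to_inception_range(pixels):
    """Map pixel values in [0, 255] to [-1, 1].

    Same arithmetic as the answer's preprocess(): divide by 255 (what
    tf.image.convert_image_dtype does for uint8 input), subtract 0.5,
    then multiply by 2.
    """
    x = np.asarray(pixels, dtype=np.float32) / 255.0
    return (x - 0.5) * 2.0
```

In the question's script this would go right after `img = img.astype(np.float32)`, as `img = scale_to_inception_range(img)`.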

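Below is a piece of code that takes an image path (JPG or PNG) and returns the preprocessed image. You can modify it to accept a batch of images. It is not professional code and needs some optimization, but it works well.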


import tensorflow as tf


def load_img(path_img):
    """
    Load an image with TensorFlow
    :param path_img: image path on the disk
    :return: 3-D NumPy array holding the decoded image
    """
    filename_queue = tf.train.string_input_producer([path_img])  # list of files to read

    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)

    my_img = tf.image.decode_image(value)  # detects PNG or JPEG from the file contents

    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)

        # Start populating the filename queue.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(1):  # length of your filename list
            image = my_img.eval()  # image is now a NumPy array

        # Image.fromarray(np.asarray(image)).show()

        # Stop the queue-runner threads before leaving the session
        coord.request_stop()
        coord.join(threads)

    return image


def preprocess(image, height, width,
               central_fraction=0.875, scope=None):
    """Prepare one image for evaluation.

    If height and width are specified it outputs an image with that size by
    applying resize_bilinear.

    If central_fraction is specified it crops the central fraction of the
    input image.

    Args:
      image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
        [0, 1], otherwise it is converted to tf.float32 assuming that the range
        is [0, MAX], where MAX is the largest positive representable number for
        the int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
      height: integer
      width: integer
      central_fraction: Optional Float, fraction of the image to crop.
      scope: Optional scope for name_scope (unused in this version).
    Returns:
      3-D float Tensor of prepared image.
    """
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Crop the central region of the image, keeping 87.5% of its height and
    # width (about 77% of the original area).
    if central_fraction:
        image = tf.image.central_crop(image, central_fraction=central_fraction)

    if height and width:
        # Resize the image to the specified height and width.
        image = tf.expand_dims(image, 0)
        image = tf.image.resize_bilinear(image, [height, width],
                                         align_corners=False)
        image = tf.squeeze(image, [0])
    # Scale pixel values from [0, 1] to [-1, 1]
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    return image
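To check the pipeline without a TensorFlow session, here is a NumPy-only sketch that approximates `preprocess` for a uint8 HxWx3 image: central crop, bilinear resize, then scaling to [-1, 1]. Assumptions: the crop offset is `int(size * (1 - central_fraction) / 2)`, and the resize samples source coordinate = output index × (input size / output size), which is my understanding of TF 1.x `resize_bilinear` with `align_corners=False`. Results may differ slightly from the TensorFlow ops. `central_crop_np`, `resize_bilinear_np` and `preprocess_numpy` are hypothetical names.

```python
import numpy as np


def central_crop_np(image, central_fraction=0.875):
    """Keep the central fraction of an HxWxC array along each dimension."""
    h, w = image.shape[:2]
    top = int(h * (1.0 - central_fraction) / 2.0)
    left = int(w * (1.0 - central_fraction) / 2.0)
    return image[top:h - top, left:w - left]


def resize_bilinear_np(image, out_h, out_w):
    """Bilinear resize of an HxWxC float array (no half-pixel offset)."""
    in_h, in_w = image.shape[:2]
    ys = np.arange(out_h) * (in_h / float(out_h))
    xs = np.arange(out_w) * (in_w / float(out_w))
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None, None]  # vertical weights, shape (out_h, 1, 1)
    wx = (xs - x0)[None, :, None]  # horizontal weights, shape (1, out_w, 1)
    top = image[y0][:, x0] * (1 - wx) + image[y0][:, x1] * wx
    bottom = image[y1][:, x0] * (1 - wx) + image[y1][:, x1] * wx
    return top * (1 - wy) + bottom * wy


def preprocess_numpy(image, height=299, width=299, central_fraction=0.875):
    """uint8 HxWx3 image -> float32 (height, width, 3) array in [-1, 1]."""
    x = image.astype(np.float32) / 255.0  # like convert_image_dtype for uint8
    if central_fraction:
        x = central_crop_np(x, central_fraction)
    if height and width:
        x = resize_bilinear_np(x, height, width)
    return ((x - 0.5) * 2.0).astype(np.float32)
```

Comparing `preprocess_numpy(img)` with the session output of `preprocess` on the same image is a quick way to confirm the input range the frozen graph receives.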

Finally, in my case I had to convert the processed tensor into a NumPy array:

image = tf.Session().run(image)

Then I feed the preprocessed image to the frozen graph:

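# `graph` is the frozen graph, imported with tf.import_graph_def(..., name="prefix")
persistent_sess = tf.Session(graph=graph)

input_node = graph.get_tensor_by_name('prefix/batch:0')
output_node = graph.get_tensor_by_name('prefix/InceptionResnetV2/Logits/Predictions:0')

# [image] adds the batch dimension expected by the input node
predictions = persistent_sess.run(output_node, feed_dict={input_node: [image]})
label_predicted = np.argmax(predictions[0])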
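The answer notes that the loader can be adapted to a batch of images. Since the input node takes a leading batch dimension (the question adds it with `np.expand_dims`, the snippet above with `[image]`), several preprocessed images can be stacked with `np.stack` and the result decoded row by row. A sketch, assuming every image is already 299x299x3, that the frozen input tensor accepts the batch size you feed, and that the model returns one probability row per image; `make_batch` and `decode_batch` are hypothetical names.

```python
import numpy as np


def make_batch(images):
    """Stack a list of HxWxC arrays into one N x H x W x C float32 batch."""
    return np.stack([np.asarray(im, dtype=np.float32) for im in images], axis=0)


def decode_batch(predictions, labels):
    """Map each row of an N x num_classes array to its (label, probability)."""
    idx = np.argmax(predictions, axis=1)
    return [(labels[int(i)], float(predictions[n, i])) for n, i in enumerate(idx)]
```

Usage would be `batch = make_batch([image_a, image_b])`, then `predictions = persistent_sess.run(output_node, feed_dict={input_node: batch})` and `decode_batch(predictions, labels_dict)`.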

Regarding python - TensorFlow: Accuracy drops drastically after freezing the graph?, we found a similar question on Stack Overflow:

Copyright 2021 - 2024 cfsdn All Rights Reserved. Sichuan ICP Registration No. 2022000587