c++ - readNetFromTensorflow error in C++

Reposted. Author: 塔克拉玛干. Updated: 2023-11-03 07:05:43

I am new to deep learning. As a first step, I created and trained a model in Python with Keras and froze it with the following code:

import tensorflow as tf
from keras import backend as K
from tensorflow.python.tools import freeze_graph, optimize_for_inference_lib

def export_model(MODEL_NAME, input_node_name, output_node_name):
    # Write the (unfrozen) graph structure as a text GraphDef
    tf.train.write_graph(K.get_session().graph_def, 'out',
                         MODEL_NAME + '_graph.pbtxt')

    # Save the trained variables to a checkpoint
    tf.train.Saver().save(K.get_session(), 'out/' + MODEL_NAME + '.chkp')

    # Merge graph structure + checkpoint into a single frozen (binary) GraphDef
    freeze_graph.freeze_graph('out/' + MODEL_NAME + '_graph.pbtxt', None,
                              False, 'out/' + MODEL_NAME + '.chkp', output_node_name,
                              "save/restore_all", "save/Const:0",
                              'out/frozen_' + MODEL_NAME + '.pb', True, "")

    # Reload the frozen graph
    input_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + MODEL_NAME + '.pb', "rb") as f:
        input_graph_def.ParseFromString(f.read())

    # Drop nodes that are not needed for inference between input and output
    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def, [input_node_name], [output_node_name],
        tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile('out/opt_' + MODEL_NAME + '.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())

This produces the following files:

  • checkpoint
  • Model.chkp.data-00000-of-00001
  • Model.chkp.index
  • Model.chkp.meta
  • Model_graph.pbtxt
  • frozen_Model.pb
  • opt_Model.pb
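The export code writes all of these into an out/ directory, while the C++ snippet below opens bare file names relative to its working directory. A minimal stdlib sketch to confirm the files are where the loader will look; the helpers expected_export_files and missing_files are hypothetical, and the file list simply mirrors the output above (the data shard name in particular depends on how the checkpoint was saved).

```python
import os

def expected_export_files(model_name, out_dir="out"):
    """Paths that export_model() writes for a given MODEL_NAME (mirrors the list above)."""
    names = [
        "checkpoint",
        model_name + ".chkp.data-00000-of-00001",
        model_name + ".chkp.index",
        model_name + ".chkp.meta",
        model_name + "_graph.pbtxt",
        "frozen_" + model_name + ".pb",
        "opt_" + model_name + ".pb",
    ]
    return [os.path.join(out_dir, n) for n in names]

def missing_files(paths):
    """Return the paths that do not exist as regular files (relative to the current directory)."""
    return [p for p in paths if not os.path.isfile(p)]

if __name__ == "__main__":
    for path in missing_files(expected_export_files("Model")):
        print("missing:", path)
```

Run it from the same working directory as the C++ program; if it reports out/frozen_Model.pb as present but the C++ code still opens "frozen_Model.pb", the two sides are looking in different places.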

When I try to read the network in OpenCV C++ with readNetFromTensorflow:

#include <opencv2/dnn.hpp>
using namespace cv;

String weights = "frozen_Model.pb";
String pbtxt = "Model_graph.pbtxt";
dnn::Net cvNet = cv::dnn::readNetFromTensorflow(weights, pbtxt);

I get the following errors:

OpenCV(4.0.0-pre) Error: Unspecified error (FAILED: ReadProtoFromBinaryFile(param_file, param). Failed to parse GraphDef file: frozen_Model.pb) in cv::dnn::ReadTFNetParamsFromBinaryFileOrDie, file D:\LIBS\OpenCV-4.00\modules\dnn\src\tensorflow\tf_io.cpp, line 44

OpenCV(4.0.0-pre) Error: Assertion failed (const_layers.insert(std::make_pair(name, li)).second) in cv::dnn::experimental_dnn_v4::`anonymous-namespace'::addConstNodes, file D:\LIBS\OpenCV-4.00\modules\dnn\src\tensorflow\tf_importer.cpp, line 555

How can I fix these errors?
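One thing worth checking before changing the export is whether the text graph still contains Keras training-time control flow, which is what the answer below addresses. This is a heuristic sketch under stated assumptions: in TF 1.x Keras graphs the learning-phase placeholder is usually named keras_learning_phase, and a Dropout layer built while that placeholder is live goes through tf.cond, which emits Switch and Merge ops. training_phase_nodes is a hypothetical helper, and the path out/Model_graph.pbtxt assumes MODEL_NAME = "Model".

```python
import os
import re

# Assumption: Switch/Merge ops come from tf.cond branches that Keras adds for the learning phase.
SUSPICIOUS_OPS = {"Switch", "Merge"}

# In a text-format GraphDef each node prints its name field immediately before its op field.
NODE_RE = re.compile(r'name:\s*"([^"]+)"\s*op:\s*"([^"]+)"')

def training_phase_nodes(pbtxt_text):
    """Return (name, op) pairs that look like Keras training-time control flow."""
    hits = []
    for name, op in NODE_RE.findall(pbtxt_text):
        if op in SUSPICIOUS_OPS or "keras_learning_phase" in name:
            hits.append((name, op))
    return hits

if __name__ == "__main__":
    path = "out/Model_graph.pbtxt"  # assumes MODEL_NAME = "Model"
    if os.path.isfile(path):
        with open(path) as f:
            for name, op in training_phase_nodes(f.read()):
                print(op, name)
```

An empty result does not prove the graph is importable; it only rules out this one likely source of trouble.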

Best answer

Amin, could you please try saving the graph in test mode:

import tensorflow as tf
import keras as K  # in this snippet K refers to the keras package itself
from keras.models import Sequential
from keras.layers import Conv2D, Dropout, Dense

# MODEL_NAME, Width, Height and image_channel come from the asker's own code
K.backend.set_learning_phase(0)  # <--- This setting makes all the following layers work in test mode

model = Sequential(name = MODEL_NAME)
model.add(Conv2D(filters = 128, kernel_size = (5, 5), activation = 'relu', name = 'FirstLayerConv2D_No1', input_shape = (Width, Height, image_channel)))
...
model.add(Dropout(0.25))
model.add(Dense(100, activation = 'softmax', name = 'endNode'))

# Create a graph definition (with no weights) and write it in binary form
sess = K.backend.get_session()
tf.train.write_graph(sess.graph.as_graph_def(), "", 'graph_def.pb', as_text=False)

Then freeze your checkpoint files with the freeze_graph.py script, using the newly created graph_def.pb as the input graph (don't forget the --input_binary flag).
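The freezing step can be scripted as below. This sketch only builds the command line for freeze_graph.py; --input_graph, --input_binary, --input_checkpoint, --output_node_names and --output_graph are that script's options, while the script path, the checkpoint prefix out/Model.chkp (from export_model with MODEL_NAME = "Model") and the output node name endNode/Softmax (what a Dense layer named endNode with softmax activation typically produces) are assumptions to adjust. freeze_graph_command is a hypothetical helper.

```python
import shlex
import sys

def freeze_graph_command(input_graph, checkpoint, output_node_names, output_graph,
                         script="freeze_graph.py"):
    """Build argv for TensorFlow's freeze_graph.py when the input graph is a binary GraphDef."""
    return [
        sys.executable, script,
        "--input_graph=" + input_graph,
        "--input_binary",                      # graph_def.pb was written with as_text=False
        "--input_checkpoint=" + checkpoint,    # checkpoint prefix, not a single file
        "--output_node_names=" + output_node_names,
        "--output_graph=" + output_graph,
    ]

if __name__ == "__main__":
    cmd = freeze_graph_command("graph_def.pb", "out/Model.chkp",
                               "endNode/Softmax", "out/frozen_Model.pb")
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

Passing the list to subprocess.run(cmd, check=True) would execute it. Without --input_binary the script would try to parse graph_def.pb as text and fail.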

Source: a similar question, "c++ - readNetFromTensorflow error in C++", on Stack Overflow: https://stackoverflow.com/questions/50643009/
