python - Using py_func in TensorFlow - ValueError: callback pyfunc_0 is not found

Reposted. Author: 太空宇宙. Updated: 2023-11-03 20:52:28
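I am trying to build a TensorFlow model in which I load a pickle file and another model as part of the TensorFlow model. The code has two parts: one creates the model (saving) and the other uses the model to make predictions (loading). I get ValueError: callback pyfunc_0 is not found.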

The .pb file itself is very small, so it looks like the model from the .pickle file is not being stored inside the .pb file. I don't know what to do.

Saving part

import tensorflow as tf
from keras import backend as K

from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants, signature_constants, signature_def_utils_impl
from keras.callbacks import TensorBoard

from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.optimizers import SGD
import numpy as np
import pickle

model_version = "465555564"
epoch = 100
tensorboard = TensorBoard(log_dir='./logs', histogram_freq = 0, write_graph = True, write_images = False)

sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(0)

def my_func(x):
    with open(PATH_TO_PICKLE, "rb") as f:
        loadCF = pickle.load(f)
    return np.float32(loadCF.predict([x])[1])

input = tf.placeholder(tf.float32)
y = tf.py_func(my_func, [input], tf.float32)

prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"inputs": input}, {"prediction": y})
builder = saved_model_builder.SavedModelBuilder('./'+model_version)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
    sess, [tag_constants.SERVING],
    signature_def_map={
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature,
    },
    legacy_init_op=legacy_init_op)

builder.save()

Loading part

sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'inputs'
output_key = 'prediction'

export_path = './465555564/'
meta_graph_def = tf.saved_model.loader.load(
    sess,
    [tf.saved_model.tag_constants.SERVING],
    export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: [0.0, 3.0,2.0,1.0,1.0,0.0,1.0,3.0,1.0,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,
0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,
1.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,1.000,0.000,0.000,0.000,
0.000,0.000,1.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,
0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,
0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.000,0.0281021,
1.1674791,0.0772629,1.00919452640745377359,-0.40733408431212109191,0.27344889607694411460,-0.27692477736208176431,
0.90979100598229301067,0.30854060293899643330,-0.89088669667641318117,0.71015013257662451540,-0.45934534155660206034,
-1.5771756172180175781,-0.44342430101500618367,0.99046792752212953204,0.77406677189800476846,0.22008506072840341994,
-0.31012541014287209329,-0.30062459437047234223,-0.02684695402988129115,0.17956349253654479980,
-0.46235901945167118265,0.42958878223887747572,-0.44371617585420608521,-0.84945221741994225706,
0.63907705081833732219,-0.70754766008920144671,0.48411194566223358926,-0.12378847102324168350,
0.15848264263735878377]})
print(y_out)

Best answer

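tf.py_func does not support saving in the .pb format; use the checkpoint format instead. The body of the Python function is not serialized into the graph. The graph only records a token such as pyfunc_0 that points at a function registered in the process that built it, so a fresh process that loads the .pb has no callback under that name. With checkpoints, the loading code builds the graph again (calling tf.py_func again) and restores only the variables.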
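To make the failure mode concrete, here is a small toy model in plain Python. It does not use TensorFlow and is not TensorFlow's implementation: ToyFuncRegistry, the JSON "graph" and the stand-in my_func are all hypothetical names chosen for illustration. It only sketches the idea that the saved artifact carries a token while the callable itself lives in process memory.

```python
import json


class ToyFuncRegistry:
    """Toy stand-in for a per-process table of Python callbacks (hypothetical)."""

    def __init__(self):
        self._funcs = {}   # token -> callable, held only in this process's memory
        self._next_id = 0

    def register(self, func):
        # Hand out tokens in registration order: pyfunc_0, pyfunc_1, ...
        token = "pyfunc_%d" % self._next_id
        self._next_id += 1
        self._funcs[token] = func
        return token

    def call(self, token, *args):
        func = self._funcs.get(token)
        if func is None:
            raise ValueError("callback %s is not found" % token)
        return func(*args)


def my_func(x):
    # Placeholder for the pickle-backed predict call in the question.
    return float(x) * 2.0


# "Saving" process: the serialized graph holds the token, not the function.
saving_registry = ToyFuncRegistry()
token = saving_registry.register(my_func)
serialized_graph = json.dumps({"op": "PyFunc", "token": token})

# "Loading" process: a fresh registry has nothing stored under that token.
loading_registry = ToyFuncRegistry()
node = json.loads(serialized_graph)
try:
    loading_registry.call(node["token"], 3.0)
    load_error = None
except ValueError as exc:
    load_error = str(exc)

# Rebuilding the graph in code registers the function again, so the call works.
rebuilt_token = loading_registry.register(my_func)
result = loading_registry.call(rebuilt_token, 3.0)
```

In this sketch the serialized text stays tiny because it never contains the function or the pickled model, which matches the small .pb file described in the question. The last two lines mirror the checkpoint approach: the loading side runs the graph-building code itself instead of relying on the saved file to carry the callback.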

For more on python - Using py_func in TensorFlow - ValueError: callback pyfunc_0 is not found, see the similar question on Stack Overflow: https://stackoverflow.com/questions/56227140/
