gpt4 book ai didi

python-2.7 - 将 tensorflow 模型保存到文件

转载 作者:行者123 更新时间:2023-12-01 18:42:49 26 4
gpt4 key购买 nike

我创建了一个 tensorflow 模型,我想将其保存到文件中,以便以后可以对其进行预测。特别是,我需要保存:

  • input_placeholder
    (= tf.placeholder(tf.float32, [无, iVariableLen]))
  • 解决方案空间
    (= tf.nn.sigmoid(tf.matmul(input_placeholder,weight_variable)+bias_variable))
  • session
    (= tf.Session())

我尝试过使用 pickle,它适用于 sklearn 二值化器等其他对象,但不适用于上面的对象,因此我在底部收到错误。

我如何 pickle :

import pickle
with open(sModelSavePath, 'w') as fiModel:
pickle.dump(dModel, fiModel)

其中dModel是一个字典,其中包含我想要保留的所有对象,我用它来进行拟合。

关于如何 pickle tensorflow 对象有什么建议吗?

错误消息:

pickle.dump(dModel, fiModel)
...
raise TypeError, "can't pickle %s objects" % base.__name__
TypeError: can't pickle module objects

最佳答案

我解决这个问题的方法是 pickleing Sklearn 对象,例如二值化器,并使用 tensorflow's inbuilt save functions对于实际模型:

保存 tensorflow 模型:
1)像平常一样构建模型
2) 使用tf.train.Saver()保存 session 。例如:

oSaver = tf.train.Saver()

oSess = oSession
oSaver.save(oSess, sModelPath) #filename ends with .ckpt

3) 这会将该 session 中的所有可用变量等保存到其变量名称中。

加载 tensorflow 模型:
1)整个流程需要重新初始化。换句话说,需要声明变量、权重、偏差、损失函数等,然后通过将 tf.initialize_all_variables() 传递给 oSession.run()< br/>2) 现在需要将该 session 传递给加载程序。我抽象了流程,所以我的加载器看起来像这样:

dAlg = tf_training_algorithm()  #defines variables etc and initializes session

oSaver = tf.train.Saver()
oSaver.restore(dAlg['oSess'], sModelPath)

return {
'oSess': dAlg['oSess'],
#the other stuff I need from my algorithm, like my solution space etc
}

3)预测所需的所有对象都需要从初始化中取出,在我的例子中位于 dAlg

PS:像这样 pickle :

with open(sSavePathFilename, 'w') as fiModel:
pickle.dump(dModel, fiModel)

with open(sFilename, 'r') as fiModel:
dModel = pickle.load(fiModel)

关于python-2.7 - 将 tensorflow 模型保存到文件,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38000180/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com