gpt4 book ai didi

tensorflow - 如何将多个相同模型从保存文件加载到 Tensorflow 中的一个 session 中

转载 作者:行者123 更新时间:2023-12-03 22:18:59 25 4
gpt4 key购买 nike

在我的主代码中,我基于这样的配置文件创建了一个模型

with tf.variable_scope('MODEL') as topscope:
model = create_model(config_file)#returns input node, output node, and some other placeholders

此范围的名称在所有保存中均相同。

然后我定义了一个优化器和一个成本函数等(它们在这个范围之外)

然后我创建一个保护程序并保存它:
saver = tf.train.Saver(max_to_keep=10)
saver.save(sess, 'unique_name', global_step=t)

现在我已经创建并保存了 10 个不同的模型,我想像这样一次性加载它们:
models = []
for config, save_path in zip(configs, save_paths):
models.append(load_model(config, save_path))

并且能够运行它们并比较它们的结果、混合它们、平均等等。对于这些加载的模型,我不需要优化器槽变量。我只需要“模型”范围内的那些变量。

我需要创建多个 session 吗?

我该怎么做?我不知道从哪里开始。我可以从我的配置文件创建一个模型,然后使用这个相同的配置文件加载这个相同的模型,并像这样保存:
saver.restore(sess, save_path)

但是我如何加载不止一个?

编辑:我不知道这个词。我想制作一个网络集合。
问它但仍然没有回答的问题: How to create ensemble in tensorflow?

编辑 2:好的,这是我现在的解决方法:

这是我的主要代码,它创建一个模型,训练它并保存它:
import tensorflow as tf
from util import *

OLD_SCOPE_NAME = 'scope1'

sess = tf.Session()

with tf.variable_scope(OLD_SCOPE_NAME) as topscope:
model = create_model(tf, 6.0, 7.0)
sc_vars = get_all_variables_from_top_scope(tf, topscope)

print([v.name for v in sc_vars])

sess.run(tf.initialize_all_variables())
print(sess.run(model))

saver = tf.train.Saver()
saver.save(sess, OLD_SCOPE_NAME)

然后我运行此代码创建相同的模型,加载其检查点保存并重命名变量:
#RENAMING PART, different file
#create the same model as above here
import tensorflow as tf
from util import *
OLD_SCOPE_NAME = 'scope1'
NEW_SCOPE_NAME = 'scope2'

sess = tf.Session()

with tf.variable_scope(OLD_SCOPE_NAME) as topscope:
model = create_model(tf, 6.0, 7.0)
sc_vars = get_all_variables_from_top_scope(tf, topscope)

print([v.name for v in sc_vars])

saver = tf.train.Saver()
saver.restore(sess, OLD_SCOPE_NAME)
print(sess.run(model))


#assuming that we change top scope, not something in the middle, functionality can be added without much trouble I think
#not sure why I need to remove ':0' part, but it seems to work okay
print([NEW_SCOPE_NAME + v.name[len(OLD_SCOPE_NAME):v.name.rfind(':')] for v in sc_vars])
new_saver = tf.train.Saver(var_list={NEW_SCOPE_NAME + v.name[len(OLD_SCOPE_NAME):v.name.rfind(':')]:v for v in sc_vars})
new_saver.save(sess, NEW_SCOPE_NAME)

然后将此模型加载到包含附加变量和新名称的文件中:
import tensorflow as tf
from util import *
NEW_SCOPE_NAME = 'scope2'
sess = tf.Session()

with tf.variable_scope(NEW_SCOPE_NAME) as topscope:
model = create_model(tf, 5.0, 4.0)
sc_vars = get_all_variables_from_top_scope(tf, topscope)
q = tf.Variable(tf.constant(0.0, shape=[1]), name='q')

print([v.name for v in sc_vars])

saver = tf.train.Saver(var_list=sc_vars)
saver.restore(sess, NEW_SCOPE_NAME)
print(sess.run(model))

工具.py:
def get_all_variables_from_top_scope(tf, scope):
#scope is a top scope here, otherwise change startswith part
return [v for v in tf.all_variables() if v.name.startswith(scope.name)]

def create_model(tf, param1, param2):
w = tf.get_variable('W', shape=[1], initializer=tf.constant_initializer(param1))
b = tf.get_variable('b', shape=[1], initializer=tf.constant_initializer(param2))
y = tf.mul(w, b, name='mul_op')#no need to save this
return y

最佳答案

在概念层面:

  • 有两个独立的东西:图形和 session
  • 首先创建图形。它定义了您的模型。没有理由不能在一张图中存储多个模型。没关系。它还定义了变量,但实际上并不包含它们的状态
  • 在图表之后创建一个 session
  • 它是从图形创建的
  • 您可以从图表中创建任意数量的 session
  • 它保存图中不同变量的状态,即各种模型中的权重

  • 所以:
  • 当您只加载模型定义时,您只需要:一个或多个图表。一张图就够了
  • 当您加载模型的实际权重、学习的权重/参数时,您需要从图中为此创建一个 session 。一个 session 就足够了

  • 请注意,变量都有名称,并且它们必须是唯一的。您可以在图表中使用变量范围为它们指定唯一名称,例如:
    with tf.variable_scope("some_scope_name"):
    # created model nodes here...

    这将在 Tensorboard 图中很好地将您的节点组合在一起。

    好的,稍微重读一下你的问题。看起来您想一次保存/加载单个模型。

    保存/加载模型的参数/权重发生在 session 中, session 包含图中定义的每个变量的权重/参数。

    您可以通过名称引用这些变量,例如通过您在上面创建的范围,并将这些变量的子集保存到不同的文件等中。

    顺便说一句,也可以使用 session.run(...)获取权重/参数的值,作为 numpy 张量,然后您可以选择pickle,或者其他什么,如果您选择。

    关于tensorflow - 如何将多个相同模型从保存文件加载到 Tensorflow 中的一个 session 中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36828600/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com