gpt4 book ai didi

tensorflow - 如何在 tensorflow 中创建集成?

转载 作者:行者123 更新时间:2023-12-03 20:56:50 34 4
gpt4 key购买 nike

我正在尝试创建一个由许多训练有素的模型组成的集合。所有模型都有相同的图,只是权重不同。我正在使用 tf.get_variable 创建模型图.对于同一个图形架构,我有几个不同的检查点(具有不同的权重),我想为每个检查点制作一个实例模型。

如何在不覆盖先前加载的权重的情况下加载多个检查点?

当我用 tf.get_variable 创建我的图表时,我可以创建多个图形的唯一方法是传递参数 reuse = True .现在,如果我尝试在加载之前更改将构建方法包含在新范围内的图形变量的名称(因此它们无法与其他创建的图形共享),那么这将不起作用,因为新名称将与保存的名称不同重量,我将无法加载它。

最佳答案

这需要一些技巧。让我们保存几个简单的模型

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import tensorflow as tf


def build_graph(init_val=0.0):
x = tf.placeholder(tf.float32)
w = tf.get_variable('w', initializer=init_val)
y = x + w
return x, y


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--init', help='dummy string', type=float)
parser.add_argument('--path', help='dummy string', type=str)
args = parser.parse_args()

x1, y1 = build_graph(args.init)

saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(y1, {x1: 10})) # outputs: 10 + i

save_path = saver.save(sess, args.path)
print("Model saved in path: %s" % save_path)

# python ensemble.py --init 1 --path ./models/model1.chpt
# python ensemble.py --init 2 --path ./models/model2.chpt
# python ensemble.py --init 3 --path ./models/model3.chpt

这些模型产生“10 + i”的输出,其中 i=1、2、3。
请注意,此脚本多次创建、运行和保存相同的图形结构。加载这些值并单独恢复每个图形是民间传说,可以通过

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import tensorflow as tf


def build_graph(init_val=0.0):
x = tf.placeholder(tf.float32)
w = tf.get_variable('w', initializer=init_val)
y = x + w
return x, y


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--path', help='dummy string', type=str)
args = parser.parse_args()

x1, y1 = build_graph(-5.)

saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())

saver.restore(sess, args.path)
print("Model loaded from path: %s" % args.path)

print(sess.run(y1, {x1: 10}))

# python ensemble_load.py --path ./models/model1.chpt # gives 11
# python ensemble_load.py --path ./models/model2.chpt # gives 12
# python ensemble_load.py --path ./models/model3.chpt # gives 13

这些再次产生预期的输出 11,12,13。现在的诀窍是从整体中为每个模型创建自己的范围,例如

def build_graph(x, init_val=0.0):
w = tf.get_variable('w', initializer=init_val)
y = x + w
return x, y


if __name__ == '__main__':
models = ['./models/model1.chpt', './models/model2.chpt', './models/model3.chpt']
x = tf.placeholder(tf.float32)
outputs = []
for k, path in enumerate(models):
# THE VARIABLE SCOPE IS IMPORTANT
with tf.variable_scope('model_%03i' % (k + 1)):
outputs.append(build_graph(x, -100 * np.random.rand())[1])

因此,每个模型都存在于不同的变量范围内,即。我们有变量 'model_001/w:0, model_002/w:0, model_003/w:0' 虽然它们有相似(不相同)的子图,但这些变量确实是不同的对象。现在,诀窍是管理两组变量(当前范围内的图形和检查点中的那些):

def restore_collection(path, scopename, sess):
# retrieve all variables under scope
variables = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scopename)}
# retrieves all variables in checkpoint
for var_name, _ in tf.contrib.framework.list_variables(path):
# get the value of the variable
var_value = tf.contrib.framework.load_variable(path, var_name)
# construct expected variablename under new scope
target_var_name = '%s/%s:0' % (scopename, var_name)
# reference to variable-tensor
target_variable = variables[target_var_name]
# assign old value from checkpoint to new variable
sess.run(target_variable.assign(var_value))

完整的解决方案是

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf


def restore_collection(path, scopename, sess):
# retrieve all variables under scope
variables = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scopename)}
# retrieves all variables in checkpoint
for var_name, _ in tf.contrib.framework.list_variables(path):
# get the value of the variable
var_value = tf.contrib.framework.load_variable(path, var_name)
# construct expected variablename under new scope
target_var_name = '%s/%s:0' % (scopename, var_name)
# reference to variable-tensor
target_variable = variables[target_var_name]
# assign old value from checkpoint to new variable
sess.run(target_variable.assign(var_value))


def build_graph(x, init_val=0.0):
w = tf.get_variable('w', initializer=init_val)
y = x + w
return x, y


if __name__ == '__main__':
models = ['./models/model1.chpt', './models/model2.chpt', './models/model3.chpt']
x = tf.placeholder(tf.float32)
outputs = []
for k, path in enumerate(models):
with tf.variable_scope('model_%03i' % (k + 1)):
outputs.append(build_graph(x, -100 * np.random.rand())[1])

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())

print(sess.run(outputs[0], {x: 10})) # random output -82.4929
print(sess.run(outputs[1], {x: 10})) # random output -63.65792
print(sess.run(outputs[2], {x: 10})) # random output -19.888203

print(sess.run(W[0])) # randomly initialize value -92.4929
print(sess.run(W[1])) # randomly initialize value -73.65792
print(sess.run(W[2])) # randomly initialize value -29.888203

restore_collection(models[0], 'model_001', sess) # restore all variables from different checkpoints
restore_collection(models[1], 'model_002', sess) # restore all variables from different checkpoints
restore_collection(models[2], 'model_003', sess) # restore all variables from different checkpoints

print(sess.run(W[0])) # old values from different checkpoints: 1.0
print(sess.run(W[1])) # old values from different checkpoints: 2.0
print(sess.run(W[2])) # old values from different checkpoints: 3.0

print(sess.run(outputs[0], {x: 10})) # what we expect: 11.0
print(sess.run(outputs[1], {x: 10})) # what we expect: 12.0
print(sess.run(outputs[2], {x: 10})) # what we expect: 13.0

# python ensemble_load_all.py

现在有了输出列表,您可以在 TensorFlow 中对这些值求平均值或进行其他一些集成预测。

编辑 :
  • 使用 NumPy (npz) 将模型存储为 numpy 字典并加载这些值更容易,就像我在这里的回答一样:
    https://stackoverflow.com/a/50181741/7443104
  • 上面的代码只是说明了一个解决方案。它没有健全性检查(就像变量确实存在一样)。 try catch 可能会有所帮助。
  • 关于tensorflow - 如何在 tensorflow 中创建集成?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35464652/

    34 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com