gpt4 book ai didi

python - 在 TensorFlow 中重命名已保存模型的变量范围

转载 作者:太空狗 更新时间:2023-10-29 17:01:26 25 4
gpt4 key购买 nike

是否可以在 tensorflow 中重命名给定模型的变量范围?

例如,我根据教程为 MNIST 数字创建了一个逻辑回归模型:

with tf.variable_scope('my-first-scope'):
NUM_IMAGE_PIXELS = 784
NUM_CLASS_BINS = 10
x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS])
y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS])

W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS,NUM_CLASS_BINS]))
b = tf.Variable(tf.zeros([NUM_CLASS_BINS]))

y = tf.nn.softmax(tf.matmul(x,W) + b)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
saver = tf.train.Saver([W, b])

... # some training happens

saver.save(sess, 'my-model')

现在我想在 'my-first-scope' 变量范围内重新加载保存的模型,然后将所有内容再次保存到一个新文件并在 'my 的新变量范围下-second-scope'.

最佳答案

根据 keveman 的回答,我创建了一个 python 脚本,您可以执行该脚本来重命名任何 TensorFlow 检查点的变量:

https://gist.github.com/batzner/7c24802dd9c5e15870b4b56e22135c96

您可以替换变量名称中的子字符串并为所有名称添加前缀。调用脚本

python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir

带有可选参数

--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run

这是脚本的核心功能:

def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run=False):
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
with tf.Session() as sess:
for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
# Load the variable
var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)

# Set the new name
new_name = var_name
if None not in [replace_from, replace_to]:
new_name = new_name.replace(replace_from, replace_to)
if add_prefix:
new_name = add_prefix + new_name

if dry_run:
print('%s would be renamed to %s.' % (var_name, new_name))
else:
print('Renaming %s to %s.' % (var_name, new_name))
# Rename the variable
var = tf.Variable(var, name=new_name)

if not dry_run:
# Save the variables
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver.save(sess, checkpoint.model_checkpoint_path)

例子:

python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir --replace_from=scope1 --replace_to=scope1/model --add_prefix=abc/

将变量 scope1/Variable1 重命名为 abc/scope1/model/Variable1

关于python - 在 TensorFlow 中重命名已保存模型的变量范围,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37086268/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com