gpt4 book ai didi

tensorflow - 在 tensorflow 检查点中修改张量的形状

转载 作者:行者123 更新时间:2023-12-02 02:58:03 27 4
gpt4 key购买 nike

我有一个 tensorflow 检查点,在使用常规例程 tf.train.Saver()saver.restore(session, 'my_checkpoint.ckpt').

但是,现在,我想修改网络的第一层以接受形状为 [200, 200, 1] 而不是 [200, 200, 10] 的输入]

为此,我想从[3, 3, 10, 32]修改第一层对应的tensor的形状(3x3 kernel,10个输入 channel ,32个输出 channel ) 到 [3, 3, 1, 32] 通过对第 3 个维度求和。

我该怎么做?

最佳答案

我找到了一种方法,但不是那么直接。给定一个检查点,我们可以将其转换为序列化的 numpy 数组(或我们可能认为适合保存 numpy 数组字典的任何其他格式),如下所示:

checkpoint = {}
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, 'my_checkpoint.ckpt')
for x in tf.global_variables():
checkpoint[x.name] = x.eval()
np.save('checkpoint.npy', checkpoint)

可能会有一些异常需要处理,但让我们保持代码简单。

然后,我们可以对 numpy 数组进行任何我们喜欢的操作:

checkpoint = np.load('checkpoint.npy')
checkpoint = ...
np.save('checkpoint.npy', checkpoint)

最后,我们可以在构建图形后手动加载权重,如下所示:

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
checkpoint = np.load('checkpoint.npy').item()
for key, data in checkpoint.iteritems():
var_scope = ... # to be extracted from key
var_name = ... #
with tf.variable_scope(var_scope, reuse=True):
var = tf.get_variable(var_name)
sess.run(var.assign(data))

如果有更直接的方法,我会洗耳恭听!

关于tensorflow - 在 tensorflow 检查点中修改张量的形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48138041/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com