gpt4 book ai didi

tensorflow - 在 TensorFlow 中保存特定权重

转载 作者:行者123 更新时间:2023-12-03 00:11:25 25 4
gpt4 key购买 nike

在我的神经网络中,我创建了一些 tf.Variable 对象,如下所示:

weights = {
'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

在特定次数的迭代之后,如何在不保存其他变量的情况下保存权重偏差中的变量?

最佳答案

在 TensorFlow 中保存变量的标准方法是使用 tf.train.Saver 目的。默认情况下,它会保存问题中的所有变量(即 tf.all_variables() 的结果),但您可以通过传递 var_list 有选择地保存变量。 tf.train.Saver 的可选参数构造函数:

weights = {
'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)

# ...
# You need a TensorFlow Session to save variables.
sess = tf.Session()
# ...

# ...then call the following methods as appropriate:
weights_saver.save(sess) # Save the current value of the weights.
biases_saver.save(sess) # Save the current value of the biases.
<小时/>

请注意,如果您将字典传递给 tf.train.Saver构造函数(例如问题中的 weights 和/或 biases 字典),TensorFlow 将使用字典键(例如 'wc1_0' )作为任何检查点文件中相应变量的名称它创造或消耗。

默认情况下,或者如果您传递 tf.Variable 列表对象到构造函数,TensorFlow 将使用 tf.Variable.name属性(property)代替。

传递字典使您能够在提供不同 Variable.name 的模型之间共享检查点每个变量的属性。仅当您想将创建的检查点与其他模型一起使用时,此详细信息才重要。

关于tensorflow - 在 TensorFlow 中保存特定权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39450046/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com