gpt4 book ai didi

python - 运行时保存 TensorFlow 状态(权重)

转载 作者:行者123 更新时间:2023-11-28 18:57:26 25 4
gpt4 key购买 nike

我有一个运行拟合模型的 python 程序。我没有实现一个保护程序,所以我想知道是否有一种方法可以在拟合运行时直接从内存(也许从临时文件?)恢复权重

最佳答案

我认为这应该可以使用命令 tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

在不保存模型的情况下提取/恢复模型权重的代码如下所示。

import tensorflow as tf
import numpy as np

X_ = tf.placeholder(tf.float64, [None, 5], name="Input")
Y_ = tf.placeholder(tf.float64, [None, 1], name="Output")

X = np.random.randint(1,10,[10,5])
Y = np.random.randint(0,2,[10,1])

with tf.variable_scope("LogReg"):
pred = tf.layers.dense(X_, 1, activation=tf.nn.sigmoid, name = 'fc1')
loss = tf.losses.mean_squared_error(labels=Y_, predictions=pred)
training_ops = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:

all_vars= tf.global_variables()
def get_var(name):
for i in range(len(all_vars)):
if all_vars[i].name.startswith(name):
return all_vars[i]
return None

sess.run(tf.global_variables_initializer())
for i in range(200):
sess.run([training_ops], feed_dict={X_: X,Y_: Y})
Weight_Vars = sess.run([tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)], feed_dict={X_: X,Y_: Y})
print(Weight_Vars)

关于python - 运行时保存 TensorFlow 状态(权重),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56687437/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com