gpt4 book ai didi

tensorflow - `get_variable()` 无法识别 tf.estimator 的现有变量

转载 作者:行者123 更新时间:2023-12-04 15:50:17 24 4
gpt4 key购买 nike

此问题已被问到 here ,不同的是我的问题集中在Estimator .

一些上下文:我们已经使用 estimator 训练了一个模型,并获得了在 Estimator input_fn 中定义的一些变量,该函数将数据预处理为批处理。现在,我们正在转向预测。在预测过程中,我们使用相同的 input_fn读入和处理数据。 但是出错说变量(word_embeddings)不存在 (变量存在于 chkp 图中),这是 input_fn 中的相关代码位:

with tf.variable_scope('vocabulary', reuse=tf.AUTO_REUSE):
if mode == tf.estimator.ModeKeys.TRAIN:
word_to_index, word_to_vec = load_embedding(graph_params["word_to_vec"])
word_embeddings = tf.get_variable(initializer=tf.constant(word_to_vec, dtype=tf.float32),
trainable=False,
name="word_to_vec",
dtype=tf.float32)
else:
word_embeddings = tf.get_variable("word_to_vec", dtype=tf.float32)

基本上,当它处于预测模式时, else被调用以在检查点加载变量。未能识别该变量表明 a) 范围使用不当; b) 图未恢复。只要 reuse,我认为范围在这里并不重要设置正确。

我怀疑这是因为图形尚未在 input_fn 恢复。阶段。通常,通过调用 saver.restore(sess, "/tmp/model.ckpt") 来恢复图形。 reference .估算器调查 source code没有让我了解与恢复有关的任何信息,最好的方法是 MonitoredSession,它是培训的包装器。它已经从原始问题中延伸了很多,对我是否走在正确的道路上没有信心,如果有人有任何见解,我会在这里寻求帮助。

我的问题的一行摘要:如何在 tf.estimator 内恢复图形, 通过 input_fnmodel_fn ?

最佳答案

嗨,我认为您的错误仅仅是因为您没有在 tf.get_variable (at predict) 中指定形状,似乎即使要恢复变量,您也需要指定形状。

我用一个简单的线性回归估计器进行了以下测试,它只需要预测 x + 5

def input_fn(mode):
def _input_fn():
with tf.variable_scope('all_input_fn', reuse=tf.AUTO_REUSE):
if mode == tf.estimator.ModeKeys.TRAIN:
var_to_follow = tf.get_variable('var_to_follow', initializer=tf.constant(20))
x_data = np.random.randn(1000)
labels = x_data + 5
return {'x':x_data}, labels
elif mode == tf.estimator.ModeKeys.PREDICT:
var_to_follow = tf.get_variable("var_to_follow", dtype=tf.int32, shape=[])
return {'x':[0,10,100,var_to_follow]}
return _input_fn

featcols = [tf.feature_column.numeric_column('x')]
model = tf.estimator.LinearRegressor(featcols, './outdir')

这段代码工作得很好,const 的值是 20,为了好玩,在我的测试集中使用它来确认:p

但是,如果您删除 shape=[] ,它会中断,您还可以提供另一个初始值设定项,例如 tf.constant(500) 并且一切都会起作用并且将使用 20 。

通过运行
model.train(input_fn(tf.estimator.ModeKeys.TRAIN), max_steps=10000)


preds = model.predict(input_fn(tf.estimator.ModeKeys.PREDICT))
print(next(preds))

您可以将图形可视化,您会看到 a) 范围是正常的,并且 b) 图形已恢复。

希望这会帮助你。

关于tensorflow - `get_variable()` 无法识别 tf.estimator 的现有变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53480116/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com