gpt4 book ai didi

tensorflow - 我可以从估算器中获取 tensorflow session 吗?

转载 作者:行者123 更新时间:2023-12-04 17:38:46 25 4
gpt4 key购买 nike

我正在使用 tf.estimator 的 LinearRegressor 并想将我的学习率衰减(最初是指数衰减)更改为使用损失的衰减。但是为此,我需要将评估损失传递给学习率衰减张量的一些占位符,并且在这一步中,我需要 tf.session。

我尝试了 tf.get_default_session() 来获取估算器创建的 session ,但是该 session 具有估算器使用的不同图表。


def my_decay(learning_rate, global_step, decay_step, loss, decay_rate):
# If loss is not reduced, than decay with decay_rate.

loss = tf.placeholder(tf.float32)
estimator = tf.estimator.LinearRegressor(
feature_columns=feature_columns,
optimizer==lambda: tf.train.FtrlOptimizer(
learning_rate=my_decay(learning_rate=0.1,
global_step=tf.get_global_step(), decay_step=10000,
loss=loss, decay_rate=0.96)),
config=sess_config
)

for _ in range(n_epoches):
metrics = tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
session.run(loss.assign(metrics['loss']))

使用上面的代码,我需要从估算器中获取session。有什么办法可以得到这个吗?

提前致谢!

最佳答案

像这样的事情的预期解决方案是子类化 tf.train.SessionRunHook并覆盖 before_run 方法以返回合适的 tf.train.SessionRunArgs .这将允许您在训练时提供值并将提取添加到 session.run 调用。您的类必须在调用之间携带对占位符和 loss 状态的引用。

然后您只需实例化该类并将 Hook 添加到 estimator.train 调用中的 hooks 参数,或者在本例中为您的 train_spec .如果您希望使用评估损失而不是训练损失,那么这可以通过向 eval_spec 添加另一个钩子(Hook)来实现,该钩子(Hook)在 after_run 方法中读取值。

关于tensorflow - 我可以从估算器中获取 tensorflow session 吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55513636/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com