gpt4 book ai didi

python - 在 TF 估计器训练 Hook 中设置变量?

转载 作者:行者123 更新时间:2023-12-01 09:00:15 24 4
gpt4 key购买 nike

train tf.estimator.Estimator的功能具有以下签名:

train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)

我正在训练一个网络,我需要根据相当复杂的算法的结果每隔几步手动设置一些变量,这无法在图中实现。是否可以在钩子(Hook)中设置变量的值?有谁知道这方面的示例代码吗?

为了不浪费资源,我不需要在每个训练步骤都调用钩子(Hook)。有没有办法指定我的钩子(Hook)只应每 N 步骤调用一次?当然,我可以自己在钩子(Hook)中保留一个计数器,并在我的算法不应该运行时返回,但这似乎应该是可配置的。

最佳答案

是的,这应该是可能的!我不确切知道这个变量存在于哪个范围内,也不知道您如何引用它,所以我只是假设您知道它的名称。我基本上是从我的其他答案中窃取代码 here .

只需在训练循环之前创建一个钩子(Hook):

class VariableUpdaterHook(tf.train.SessionRunHook):
def __init__(self, frequency, variable_name):
# variable name should be like: parent/scope/some/path/variable_name:0
self._global_step_tensor = None
self.variable = None
self.frequency = frequency
self.variable_name = variable_name

def after_create_session(self, session, coord):
self.variable = session.graph.get_tensor_by_name(self.variable_name)

def begin(self):
self._global_step_tensor = tf.train.get_global_step()

def after_run(self, run_context, run_values):
global_step = run_context.session.run(self._global_step_tensor)
if global_step % self.frequency == 0:
new_variable_value = complicated_algorithm(...)
assign_op = self.variable.assign(new_variable_value)
run_context.session.run(assign_op)

我认为不值得花精力研究另一种方法来避免每次迭代后的调用,因为它们非常便宜。所以要走的路就是你建议的。

注意:我没有时间对此进行调试,因为我目前没有用例。但我希望你能明白。

关于python - 在 TF 估计器训练 Hook 中设置变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52512282/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com