gpt4 book ai didi

python - 更改 Estimator SessionRunHook 中的 tf.Variable 值

转载 作者:行者123 更新时间:2023-12-01 01:42:00 25 4
gpt4 key购买 nike

我有一个 tf.Estimator,其 model_fn 包含一个初始化为 1.0 的 tf.Variable。我想根据开发集的准确性在每个时期更改变量值。我实现了一个 SessionRunHook 来实现此目的,但是当我尝试更改该值时,我收到以下错误:

raise RuntimeError("图表已完成,无法修改。")

这是 Hook 的代码:

    class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0

def before_run(self, run_context):
self.steps += 1

def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
with tf.variable_scope("lambda_scope", reuse=True):
lambda_tensor = tf.get_variable("lambda_value")
tf.assign(lambda_tensor, self.gamma_value)
self.gamma_value += 0.1

我知道当我运行钩子(Hook)时,图表已完成,但我想知道是否有其他方法可以在训练期间使用 Estimator API 更改 model_fn 图表中的变量值。

最佳答案

现在设置 Hook 的方式实际上是在每次 session 运行后尝试创建新的变量/操作。相反,您应该预先定义 tf.assign 操作并将其传递给钩子(Hook),以便它可以在必要时自行运行该操作,或者在钩子(Hook)的 __init__ 中定义分配操作>。您可以通过 run_context 参数访问 after_run 内的 session 。所以像

class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value, lambda_tensor):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0
self.update_op = tf.assign(lambda_tensor, self.gamma_placeholder)

def before_run(self, run_context):
self.steps += 1

def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
run_context.session.run(self.update_op)
self.gamma += 0.1

这里有一些注意事项。首先,我不确定是否可以使用这样的 Python 整数进行 tf.assign 操作,即一旦更改 gamma 后它是否会正确更新。如果这不起作用,您可以尝试以下操作:

class DynamicWeightingHook(tf.train.SessionRunHook):
def __init__(self, epoch_size, gamma_value, lambda_tensor):
self.gamma = gamma_value
self.epoch_size = epoch_size
self.steps = 0
self.gamma_placeholder = tf.placeholder(tf.float32, [])
self.update_op = tf.assign(lambda_tensor, self.gamma_placeholder)

def before_run(self, run_context):
self.steps += 1

def after_run(self, run_context, run_values):
if self.steps % epoch_size == 0: # epoch
run_context.session.run(self.update_op, feed_dict={self.gamma_placeholder: self.gamma})
self.gamma += 0.1

在这里,我们使用一个额外的占位符来始终将“当前”gamma 传递给分配操作。

其次,由于钩子(Hook)需要访问变量,因此您需要在模型函数内定义钩子(Hook)。您可以将此类 Hook 传递到 EstimatorSpec 中的训练过程(请参阅 here )。

关于python - 更改 Estimator SessionRunHook 中的 tf.Variable 值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51784864/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com