gpt4 book ai didi

python - 如何在内存中复制 tensorflow 网络状态以便参数更新后检索?

转载 作者:太空宇宙 更新时间:2023-11-03 14:10:17 25 4
gpt4 key购买 nike

我有一个包含多个模块的 tensorflow 图,我想重用其中一个模块的之前的网络状态(参数更新之前)来评估下一个状态的输入 em>(参数更新后)。

示例

考虑一下玩具示例,我希望在时间步 t 处复制 network_B,以便在下一个训练步骤 t+1 中使用:

def network_A(x):
A1 = tf.matmul(x, A_W1) + A_b1
return tf.nn.relu(A1)

def network_B(x):
B1 = tf.matmul(x, B_W1) + B_b1
Z1 = tf.nn.relu(B1)
B2 = tf.matmul(Z1, B_W2) + B_b2
return B2

x = tf.placeholder(tf.float32, shape=[None, x_dim])
x_2 = network_A(x)

# Evaluate input x_2 with current state of network
y_hatB_current = network_B(x)

# Evaluate same input x_2 with past state of network
y_hatB_past = network_B_past(x) #

# Get some loss
loss = ...

然后,一旦两者都被评估,将网络的当前状态保存为新的过去状态,并仅优化当前状态:

# Save state of parameters
network_B_past = network_B # (How do I do this efficiently?)

# Optimize the current state
train = tf.train.AdamOptimizer().minimize(loss, var_list=current_vars)

详细信息

因此,在每个训练步骤中,应该存在两个版本的 network_B 可用于评估输入:

  1. network_B 在时间步 t-1(过去状态)
  2. network_B 在时间步t(当前状态)

在两个训练步骤之间存在参数更新,因此两者之间的权重应该略有不同,但其他方面应该相同。然后,在评估新输入后,当前状态将替换过去的状态,并且另一个训练步骤将更新网络。

我知道我可以在 tensorflow 中保存和重新加载检查点,但这对于我的用例来说似乎效率太低,因为它需要在每个训练步骤中发生。实现此网络克隆步骤以便我维护跨州持续存在的副本的有效方法是什么?

tensorflow 版本:1.5

最佳答案

我将使用函数 create_graph 在不同的变量范围下创建网络两次:一次用于当前,一次用于备份。请注意,这会使内存消耗加倍。

那么您所需要的只是一个自定义sync_op。 MWE 是

import tensorflow as tf


def copy_vars(src_scope, dst_scope):
src_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=src_scope)
dst_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=dst_scope)

update_op = []

for src_var in src_vars:
for dst_var in dst_vars:
if src_var.name.replace('%s' % src_scope, '') == dst_var.name.replace('%s' % dst_scope, ''):
assert dst_var.shape == src_var.shape
print(" copy: add assign {} -> {}".format(src_var.name, dst_var.name))
update_op.append(dst_var.assign(src_var))
return tf.group(update_op)


def create_graph(name, x, use_c=False, uses_gradient_updates=True):

var_setter = lambda x: x # noqa
if uses_gradient_updates:
var_setter = lambda x: tf.stop_gradient(x) # noqa

with tf.variable_scope(name, custom_getter=var_setter):
a = tf.Variable([1], dtype=tf.float32)
b = tf.Variable([1], dtype=tf.float32)
result = x + a + b
if use_c:
# create dummy variable just to show both graphs do not need to be exactly the same
c = tf.Variable([1], dtype=tf.float32)
return result, a, b

x = tf.placeholder(tf.float32)

c1, a1, b1 = create_graph('original', x, use_c=True)
c2, a2, b2 = create_graph('backup', x, use_c=False)

sync_op = copy_vars('original', 'backup')

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())

print(sess.run([c1, c2], {x: 5})) # in sync

sess.run(a1.assign([3])) # update your graph either by tf.train.Adam or by:
print(sess.run([c1, c2], {x: 5})) # out of sync

sess.run(sync_op) # do syncing
print(sess.run([c1, c2], {x: 5})) # in sync

custom_getter 可以帮助防止渐变更新。

关于python - 如何在内存中复制 tensorflow 网络状态以便参数更新后检索?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48555964/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com