gpt4 book ai didi

python - 通过输入值来更新变量

转载 作者:太空宇宙 更新时间:2023-11-03 15:35:12 24 4
gpt4 key购买 nike

我有一个 TensorFlow 模型,我希望能够选择检索和更新权重(即在检查点机制之外)。由于我可能会多次执行此操作,因此我不想在每次执行此操作时都将节点添加到图表中,而是希望有一些可以调用的操作来执行此操作。我的想法是让 tf.assign 操作对更新的变量和值使用相同的变量,并在 feed_dict 中提供新的权重值;所以像这样:

weights_assigns = [tf.assign(v, v) for v in tf.trainable_variables()]
weights_update = tf.group(*weights_assigns)

# Now to update the weights
weights = [...] # List of new weight values
feed_dict = {v: w for v, w in zip(tf.trainable_variables(), weights)}
tf.run(weights_update_op, feed_dict=feed_dict)

在我看来,这应该将 feed_dict 中传递的值作为变量的当前值,然后通过 tf.assign 操作存储它们。但是,这不起作用,并且给我带来了一些关于意外值类型的奇怪错误。

我当前的替代方案是使用一些辅助节点(变量或占位符),并在赋值操作中用作值:

weights_updates = [tf.placeholder(v.dtype, v.get_shape()) for v in tf.trainable_variables()]
weights_assigns = [tf.assign(v, u) for v, u in zip(tf.trainable_variables(), weights_updates)]
weights_update_op = tf.group(*weights_assigns)

# Now to update the weights
weights = [...] # List of new weight values
feed_dict = {u: w for u, w in zip(weights_updates, weights)}
tf.run(weights_update_op, feed_dict=feed_dict)

这真的是唯一的方法吗?或者还有其他一些我没有看到的明显方式吗?

最佳答案

好吧,发现有一个load()方法:

with tf.Graph().as_default():
w = tf.get_variable('weights', shape=[3, 3],
initializer=tf.random_uniform_initializer(dtype=tf.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print('initial W:')
print(sess.run(w))
new_vals = np.reshape(np.arange(9, dtype=np.float32), (3,3))
w.load(new_vals)
print('updated W:')
print(sess.run(w))

也许这有帮助?

关于python - 通过输入值来更新变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42575947/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com