gpt4 book ai didi

python - 如何在自定义 python 运算符 (tf.py_func) 中更新模型(变量)?

转载 作者:太空宇宙 更新时间:2023-11-04 04:48:39 24 4
gpt4 key购买 nike

我需要在 python 中编写一个自定义的 Op,它将根据模型和另一个将更新模型的 op 生成输出。在下面的示例代码中,我有一个非常简单的缩放器模型,w(但实际上它将是一个 nxm 矩阵)。我想出了如何“读取”模型,如 custom_model_read_op 函数中所示(实际上要复杂得多)。但是,我如何创建类似的东西,以某种自定义的复杂方式(使用 custom_model_update_op)更新 w?我认为这是可能的,因为像 SGD 这样的 Optimizer ops 能够做到这一点。提前致谢!

import tensorflow as tf
import numpy

# Create a model
w = tf.Variable(numpy.random.randn(), name="weight")
X = tf.placeholder(tf.int32, shape=(), name="X")

def custom_model_read_op(i, w):
y = i*float(w)
return y
y = tf.py_func(custom_model_read_op, [X, w], [tf.float64], name="read_func")

def custom_model_update_op(i, w):


==> # How to update w (the model stored in a Variable above) based on the value of i and some crazy logic?

return 0
crazy_update = tf.py_func(custom_model_update_op, [X, w], [tf.int64], name="update_func")



with tf.Session() as sess:

tf.global_variables_initializer().run()

for i in range(10):
y_out, __ = sess.run([y, crazy_update], feed_dict={X: i})
print("y=", "{:.4f}".format(y_out[0]))

最佳答案

好吧,我不确定这是最好的方法,但它会在我需要时完成。我没有 py_func 发生 w 的更新,但我确实在 read_op 中更新它,将它作为返回传回值,最后使用 assign 函数在自定义操作之外修改它。如果任何 Tensorflow 专家可以确认这是一种很好的合法方式,我将不胜感激。

import tensorflow as tf
import numpy

# Create a model
w = tf.Variable(numpy.random.randn(), name="weight")
X = tf.placeholder(tf.int32, shape=(), name="X")

def custom_model_read_op(i, w):
y = i*float(w)
w = custom_model_update(w)
return y, w
y = tf.py_func(custom_model_read_op, [X, w], [tf.float64, tf.float64], name="read_func")

def custom_model_update(w):
# update w (the model stored in a Variable above) based on the vaue of i and some crazy logic
return w + 1

with tf.Session() as sess:

tf.global_variables_initializer().run()

for i in range(10):
y_out, w_modified = sess.run(y, feed_dict={X: i})
print("y=", "{:.4f}".format(y_out))
assign_op = w.assign(w_modified)
sess.run(assign_op)
print("w=", "{:.4f}".format(sess.run(w)))

关于python - 如何在自定义 python 运算符 (tf.py_func) 中更新模型(变量)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48931351/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com