gpt4 book ai didi

python - 将 bool 从 feed_dict 传递给函数不起作用

转载 作者:行者123 更新时间:2023-12-01 07:35:57 24 4
gpt4 key购买 nike

我正在尝试将 feed_dict value bool 传递给函数

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):

if (flag is True):
print(msg1)
vtotal = tf.add(a,b)
else:
print(msg2)
vtotal = tf.multiply(a,b)

return vtotal

当我调用 sum(a,b) 函数时,使用默认值 flag = True 进行处理

但是当我将该函数调用为

sum(a, b, flag):

我从 feed_dict 中获取 flag 的值,例如

output = sess.run(total,feed_dict = {a: a_arr, b: b_arr, flag: True})

它不将值设为 True,而是执行函数的 else 部分

完整代码如下:请帮忙解释为什么会发生这种情况。

def initialize_placeholders():
a = tf.placeholder(tf.float32,[3,None],name="a")
b = tf.placeholder(tf.float32,[3,None],name ="b")
flag = tf.placeholder(tf.bool, name="flag")

return a, b, flag

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):

if (flag is True):
print(msg1)
vtotal = tf.add(a,b)
else:
print(msg2)
vtotal = tf.multiply(a,b)

return vtotal

def model(a_arr,b_arr):
#print(a_arr)
#print(b_arr)
tf.reset_default_graph()
a, b ,flag= initialize_placeholders()
total = sum(a,b,flag)

init = tf.global_variables_initializer()
print(flag)

with tf.Session() as sess:
sess.run(init)
output = sess.run(total,feed_dict = {a: a_arr, b: b_arr, flag: True})
print(flag)
unv = sess.run(tf.report_uninitialized_variables())
sess.close()
return output, unv

a_arr = np.arange(6)
a_arr = a_arr.reshape(3,2)
b_arr = np.array([2,4,6,8,10,12])
b_arr = b_arr.reshape(3,2)
output , unv = model(a_arr,b_arr)
print(output)
print(unv)

最佳答案

您不能在常规条件 Python 语句中使用 TensorFlow 值(除非您使用类似 AutoGraph 的内容)。你可以用 tf.cond 做你想做的事像这样:

def sum(a, b, flag=True):
flag = tf.convert_to_tensor(flag)
return tf.cond(flag, lambda: tf.add(a, b), lambda: tf.multiply(a, b))

您还可以使其变得更复杂一点,以保存 tf.condflag的值预先固定时进行操作。例如,您可以有这样的内容:

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):
true_fn = lambda: tf.add(a, b)
false_fn = lambda: tf.multiply(a, b)
if flag is True:
return true_fn()
elif flag is False:
return false_fn()
else: # Use TensorFlow conditional
flag = tf.convert_to_tensor(flag)
return tf.cond(flag, true_fn, false_fn)

我删除了 print 指令,因为它们不能直接在 TensorFlow 条件中使用,但您仍然可以拥有 tf.print如果您想在执行图形时查看打印的消息,请执行以下操作。

关于python - 将 bool 从 feed_dict 传递给函数不起作用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56994607/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com