gpt4 book ai didi

python - 尝试在 tensorflow 的 map_fn 中使用 if

转载 作者:太空宇宙 更新时间:2023-11-03 11:17:19 25 4
gpt4 key购买 nike

我正在尝试使用 tensorflow 中的 map_fn 将转换应用于列向量,但它不起作用。

对于下面的列向量:

elems = np.array([[1.0], [2.0], [3.0]])

当我这样做时:

tf_m = tf.map_fn(lambda x: x + 1.0, elems)
with tf.Session() as sess:
res = sess.run(tf_m)
print(str(res))

我得到了我期望的结果,即这个列向量:

[[2.]
[3.]
[4.]]

但是,当我这样做时:

tf_m2 = tf.map_fn(lambda x: x+1 if x % 2 > 0 else x, elems)
with tf.Session() as sess:
res = sess.run(tf_m2)
print(str(res))

代码失败,出现以下异常:

TypeError: Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

我试过打印 x 的类型,它是一个形状为 (1,) 的张量。所以,看起来正在发生的事情是,这些值没有作为标量值传递到 lambda 中,而是作为形状为 (1,) 的张量; % 被广播,产生另一个形状为 (1,) 的张量,但该张量不能再应用 >= 运算符。

有没有办法让它工作?有没有办法获得一个我可以应用 >= 运算符的实际 标量?如果没有,是否有我可以使用的 map_fn 的有效替代方案?

(我查看了 tf.cond,但我不清楚如何在这种情况下使用它。据我了解,tf.cond 生成一个操作,而不是可调用的,所以我将如何使用它从 map_fn 应用的 lambda 中?)

最佳答案

您可以使用 tf.map_fn 来完成和 tf.cond像这样:

elems_shape = tf.shape(elems)
elems_flat = tf.reshape(elems, [-1])
tf_m2_flat = tf.map_fn(lambda x: tf.cond(x % 2 > 0, lambda: x + 1, lambda: x), elems_flat)
tf_m2 = tf.reshape(tf_m2_flat, elems_shape)

但您也可以简单地使用 tf.where像这样:

tf_m2 = tf.where(elems % 2 > 0, elems + 1, elems)

关于python - 尝试在 tensorflow 的 map_fn 中使用 if,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49649186/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com