gpt4 book ai didi

Tensorflow:可以阻止 tf.where 的一个分支执行吗?

转载 作者:行者123 更新时间:2023-12-04 17:51:17 27 4
gpt4 key购买 nike

我正在研究编码器-解码器设置。我希望能够运行一次编码器,然后运行多个解码器。我想出的解决方案是为解码器提供一个 TF 条件节点(使用 tf.where),它包含编码器的最终隐藏状态(在这种情况下,当我请求解码器输出时,TF 将运行编码器),或带有编码器存储结果的占位符(在这种情况下,理论上 TF 不需要运行编码器)。

这是代码的相关部分:

encoder_state = tf.where(gen_math_ops.greater_equal(branching_points, 0), encoder_state,
rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

因为我没有从这种方法中获得加速,所以我很确定它不起作用并且 tf.where 的两个分支每次都由 TF 运行,即使它只需要从占位符读取。

有没有什么方法可以使用 tf.where 使其不运行编码器?我查看了该方法的描述,但我不确定是否始终计算两个分支,我在这个问题上看到了相互矛盾的信息。

谢谢!

最佳答案

tf.cond()当您想要推迟执行其中一个分支直到对谓词求值时,可以使用函数。

encoder_state = tf.cond(
tf.greater_equal(branching_points, 0),
lambda: encoder_state,
lambda: tf.nn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

关于Tensorflow:可以阻止 tf.where 的一个分支执行吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44569219/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com