gpt4 book ai didi

python - Tensorflow,在给定条件下更改张量值

转载 作者:行者123 更新时间:2023-12-01 01:27:38 25 4
gpt4 key购买 nike

我正在将 numpy 代码转换为 Tensorflow。

它有以下行:

netout[..., 5:] *= netout[..., 5:] > obj_threshold

这不是相同的 Tensorflow 语法,我无法找到具有相同行为的函数。

首先我尝试了:

netout[..., 5:] * netout[..., 5:] > obj_threshold

但是返回的是一个 bool 张量。在本例中,我希望 obj_threshold 以下的所有值均为 0。

最佳答案

如果您只想将 obj_threshold 以下的所有值设置为 0,您可以这样做:

netout = tf.where(netout > obj_threshold, netout, tf.zeros_like(netout))

或者:

netout = netout * tf.cast(netout > obj_threshold, netout.dtype)

但是,您的情况有点复杂,因为您只希望更改影响张量的一部分。因此,您可以做的一件事是创建一个 bool 掩码,对于超过 obj_threshold 的值或最后一个索引低于 5 的值,该掩码为 True

mask = (netout > obj_threshold) | (tf.range(tf.shape(netout)[-1]) < 5)

然后您可以将其与之前的任何方法一起使用:

netout = tf.where(mask, netout, tf.zeros_like(netout))

netout = netout * tf.cast(mask, netout.dtype)

关于python - Tensorflow,在给定条件下更改张量值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53213372/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com