gpt4 book ai didi

python - 如何自己实现tf.argmax?

转载 作者:行者123 更新时间:2023-11-30 08:48:38 28 4
gpt4 key购买 nike

我想使用一个函数,该函数将张量作为输入并返回张量各轴上最大值的索引。我知道存在一个函数 tf.argmax() 具有完全相同的功能,但是我如何自己实现它(在实现某些自定义函数时这可能是必要的)?

现在我们假设该函数仅将一维张量作为输入。因此,该函数需要具有以下签名:

argmax(
input, #input is a 1D tensor
name=None
)

我尝试这样实现:

def argmax(input, name=None):
maxValue=0
maxIndex=0
for i in range(input.get_shape()[0]):
if input[i]>maxValue:
maxValue=input[i]
maxIndex=i
return maxIndex

但是这不起作用,因为在构造阶段,这些值尚未初始化,因此我无法像在上面的代码中那样比较两个值。那么,有没有一种方法可以编写自定义函数,例如 tf.argmax、tf.equal 等?

最佳答案

嗯,一种简单的方法是这样的:

idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]

示例:

import tensorflow as tf

with tf.Session() as sess:
input = tf.constant([1, 3, 4, 2, 1, 2])
idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]
print(sess.run(idx))

输出:

2

关于python - 如何自己实现tf.argmax?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52370630/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com