gpt4 book ai didi

python - 为什么 dim=1 在 torch.argmax 中返回行索引?

转载 作者:行者123 更新时间:2023-12-01 08:02:04 24 4
gpt4 key购买 nike

我正在研究 PyTorch 的 argmax 函数,其定义为:

torch.argmax(input, dim=None, keepdim=False)

考虑一个例子

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))

这里,当我使用 dim=1 而不是搜索列向量时,该函数会搜索行向量,如下所示。

print(a) :   
tensor([[-1.7739, 0.8073, 0.0472, -0.4084],
[ 0.6378, 0.6575, -1.2970, -0.0625],
[ 1.7970, -1.3463, 0.9011, -0.8704],
[ 1.5639, 0.7123, 0.0385, 1.8410]])

print(torch.argmax(a, dim=1))
tensor([1, 1, 0, 3])

据我的假设,dim = 0 代表行,dim =1 代表列。

最佳答案

是时候正确理解axisdim参数在PyTorch中的工作原理了:

tensor dimension

<小时/>

一旦您理解了上图,下面的示例就应该有意义了:

    |
v
dim-0 ---> -----> dim-1 ------> -----> --------> dim-1
| [[-1.7739, 0.8073, 0.0472, -0.4084],
v [ 0.6378, 0.6575, -1.2970, -0.0625],
| [ 1.7970, -1.3463, 0.9011, -0.8704],
v [ 1.5639, 0.7123, 0.0385, 1.8410]]
|
v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
<小时/>

注意:dim('dimension'的缩写)在 torch 中相当于'axis' NumPy。

关于python - 为什么 dim=1 在 torch.argmax 中返回行索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55691819/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com