gpt4 book ai didi

python - torch.argmax 如何为 4 维工作

转载 作者:行者123 更新时间:2023-12-04 09:06:43 24 4
gpt4 key购买 nike

我是 Pytorch 的新手。即使我阅读了文档,我也不清楚当我们有 4 维输入时,torch.argmax() 如何应用于第一维工作。另外, keepdims=True 如何改变输出?
以下是每种情况的示例:

k = torch.rand(2, 3, 4, 4)
print(k):

tensor([[[[0.2912, 0.4818, 0.1123, 0.3196],
[0.6606, 0.1547, 0.0368, 0.9475],
[0.4753, 0.7428, 0.5931, 0.3615],
[0.6729, 0.7069, 0.1569, 0.3086]],

[[0.6603, 0.7777, 0.3546, 0.2850],
[0.3681, 0.5295, 0.8812, 0.6093],
[0.9165, 0.2842, 0.0260, 0.1768],
[0.9371, 0.9889, 0.6936, 0.7018]],

[[0.5880, 0.0349, 0.0419, 0.3913],
[0.5884, 0.9408, 0.1707, 0.1893],
[0.3260, 0.4410, 0.6369, 0.7331],
[0.9448, 0.7130, 0.3914, 0.2775]]],


[[[0.9433, 0.8610, 0.9936, 0.1314],
[0.8627, 0.3103, 0.3066, 0.3547],
[0.3396, 0.1892, 0.0385, 0.5542],
[0.4943, 0.0256, 0.7875, 0.5562]],

[[0.2338, 0.2498, 0.4749, 0.2520],
[0.4405, 0.1605, 0.6219, 0.8955],
[0.2326, 0.1816, 0.5032, 0.8732],
[0.2089, 0.6131, 0.1898, 0.0517]],

[[0.1472, 0.8059, 0.6958, 0.9047],
[0.6403, 0.2875, 0.5746, 0.5908],
[0.8668, 0.4602, 0.8224, 0.9307],
[0.2077, 0.5665, 0.8671, 0.4365]]]])

argmax = torch.argmax(k, axis=1)
print(argmax):
tensor([[[1, 1, 1, 2],
[0, 2, 1, 0],
[1, 0, 2, 2],
[2, 1, 1, 1]],

[[0, 0, 0, 2],
[0, 0, 1, 1],
[2, 2, 2, 2],
[0, 1, 2, 0]]])


argmax = torch.argmax(k, axis=1, keepdims=True)
print(argmax):
tensor([[[[1, 1, 1, 2],
[0, 2, 1, 0],
[1, 0, 2, 2],
[2, 1, 1, 1]]],


[[[0, 0, 0, 2],
[0, 0, 1, 1],
[2, 2, 2, 2],
[0, 1, 2, 0]]]])

最佳答案

k是形状张量 (2, 3, 4, 4) ,根据定义,torch.argmaxaxis=1应该给你一个形状 (2, 4, 4) 的输出.要了解为什么会发生这种情况,您必须首先了解较低维度中会发生什么。
如果我有一个 2D (2, 2) 张量 A,例如:

[[1,2],
[3,4]]
然后 torch.argmax(A, axis=1)给出具有值 (1, 1) 的形状 (2) 的输出。轴参数表示要操作的轴。所以设置 axis=1意味着它会在决定最大值之前一一查看每一列的值。对于第 0 行,它查看列值 1、2 并确定 2(在索引 1 处)是最大值。对于第 1 行,它查看列值 3、4 并确定 4(在索引 1 处)是最大值。所以 argmax 结果是 [1, 1]。
向上移动到 3D,让我们有一个假设的维度数组(I、J、K)。如果我们使用axis = 1调用argmax,我们可以将其分解为以下内容:
I, J, K = 3, 4, 5
A = torch.rand(I, J, K)
out = torch.zeros((I, K), dtype=torch.int32)

for i in range(I):
for k in range(K):
out[i,k] = torch.argmax(A[i,:,k])

print(out)
print(torch.argmax(A, axis=1))

Out:
tensor([[3, 3, 2, 3, 2],
[1, 1, 0, 1, 0],
[0, 1, 0, 3, 3]], dtype=torch.int32)
tensor([[3, 3, 2, 3, 2],
[1, 1, 0, 1, 0],
[0, 1, 0, 3, 3]])
那么发生的情况是,在您的 3D 张量中,您再次沿列/轴 1 计算 argmax。因此,对于每个唯一的 (i, k) 对,您在轴 1 上正好有 J 值,对吗?这些 J 值中最大值的索引被插入到输出的位置 (i,k) 中。
如果您了解这一点,那么您就可以了解 4D 中发生的事情。对于维度 (I, J, K, L) 的任何 4D 张量,如果您使用轴 = 1 调用 argmax,那么对于 (i, k, l) 的每个组合,您将沿着轴 1 精确地具有 J 值 - 以及这些 J 值的 argmax 将出现在输出 [i,k,l] 中。 keepdims参数只是保留矩阵的维数。例如,4D 矩阵上轴 1 处的 argmax 给出形状为 (I,K,L) 的 3D 结果,但使用 keepdims,结果也将是形状为 (I,1,K,L) 的 4D。

关于python - torch.argmax 如何为 4 维工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63427246/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com