gpt4 book ai didi

python - 尺寸 M < 32 的 Pytorch 张量索引错误?

转载 作者:行者123 更新时间:2023-12-05 05:46:28 26 4
gpt4 key购买 nike

我正在尝试通过索引矩阵访问 pytorch 张量,我最近发现这段代码我找不到它不起作用的原因。

下面的代码分为两部分。前半部分证明有效,而后半部分出现错误。我看不出原因。有人可以阐明这一点吗?

import torch
import numpy as np

a = torch.rand(32, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx] # WORKS for a torch.tensor of size M >= 32. It doesn't work otherwise.

a = torch.rand(16, 16)
m, n = a.shape
xx, yy = np.meshgrid(np.arange(m), np.arange(m))
result = a[xx] # IndexError: too many indices for tensor of dimension 2

如果我更改 a = np.random.rand(16, 16) 它也能正常工作。

最佳答案

致所有前来寻找答案的人:它看起来像是 pyTorch 中的一个错误。

使用 numpy 数组进行索引的定义不明确张量使用张量进行索引时才有效。因此,在我的示例代码中,这可以完美地工作:

a = torch.rand(M, N)
m, n = a.shape
xx, yy = torch.meshgrid(torch.arange(m), torch.arange(m), indexing='xy')
result = a[xx] # WORKS

我做了一个gist to check it, and it's available here

关于python - 尺寸 M < 32 的 Pytorch 张量索引错误?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71176095/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com