gpt4 book ai didi

deep-learning - 索引错误 : index_select(): Index is supposed to be a vector

转载 作者:行者123 更新时间:2023-12-05 06:09:40 26 4
gpt4 key购买 nike

    for batch_id, (data, target) in enumerate(tqdm(train_loader)):
print(target)
print('Entered for loop')
target = torch.sparse.torch.eye(10).index_select(dim=0, index=target)
data, target = Variable(data), Variable(target)

包含 index_select 函数的行给出了这个错误,我无法在任何地方找到解决方案。打印时的目标变量如下所示:

tensor([[4],
[1],
[8],
[5],
[9],
[5],
[5],
[8],
[4],
[6]])

如何将目标变量转换为向量?不是已经是向量了吗?

最佳答案

如果您查看目标变量的形状,您会发现它是形状为以下的二维张量:

target.shape # torch.Size([10, 1])

错误信息有点困惑,但本质上索引应该是一维张量(向量)。所以使用 .squeeze 方法会:

target.squeeze().shape # torch.Size([10])

index_select 方法不会报错。

关于deep-learning - 索引错误 : index_select(): Index is supposed to be a vector,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64693739/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com