gpt4 book ai didi

python - 如何将 torch int64 转换为 torch LongTensor?

转载 作者:行者123 更新时间:2023-12-01 07:44:53 40 4
gpt4 key购买 nike

我正在学习一门类(class),该类(class)使用已弃用的 PyTorch 版本,该版本不会根据需要将 torch.int64 更改为 torch.LongTensor。当前引发错误的代码部分是:

loss = loss_fn(Ypred, Ytrain_) # 预测的计算损失

我认为应该在本节中更改数据类型:

Ytrain_ = torch.from_numpy(y_train.values).view(1, -1)[0]

使用Ytrain_.dtype测试数据类型时,它返回torch.int64。我尝试通过应用 long() 函数来转换它,如下所示:Ytrain_ = Ytrain_.long() 无济于事。

我也尝试过在 documentation 中寻找它但它似乎说 torch.int64 OR torch.long 我认为这意味着 torch.int64 应该可以工作。

RuntimeError                              Traceback (most recent call last)
----> 9 loss = loss_fn(Ypred, Ytrain_) # calc loss on the prediction
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'

最佳答案

正如user8426627所述,您想要更改张量类型,而不是数据类型。因此,解决方案是添加 .type(torch.LongTensor) 将其转换为 LongTensor

最终代码:

Ytrain_ = torch.from_numpy(Y_train.values).view(1, -1)[0].type(torch.LongTensor)

测试张量类型:

Ytrain_.type()

'torch.LongTensor'

关于python - 如何将 torch int64 转换为 torch LongTensor?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56510189/

40 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com