gpt4 book ai didi

python - 后端 CPU 的预期对象,但参数 #2 得到后端 CUDA 'source'

转载 作者:行者123 更新时间:2023-12-01 00:49:16 25 4
gpt4 key购买 nike

我尝试了其他答案,但错误没有被删除。与我收到的其他问题的区别在于,错误中使用的最后一个术语是“来源”,我在任何问题中都没有找到它。如果可能的话,还请解释错误的术语“来源”。在没有 CPU 的情况下运行代码也可以正常工作。

I am using Google Colab with GPU enabled.

import torch
from torch import nn
import syft as sy

hook = sy.TorchHook(torch)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Sequential(nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim = 1))

model = model.to(device)

输出:

---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-42-136ec343040a> in <module>()
8 nn.LogSoftmax(dim = 1))
9
---> 10 model = model.to(device)

3 frames
/usr/local/lib/python3.6/dist-packages/syft/frameworks/torch/hook/hook.py in data(self, new_data)
368
369 with torch.no_grad():
--> 370 self.set_(new_data)
371 return self
372

RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'source'

最佳答案

此问题与 PySyft 有关。正如您在 Issue#1893 中看到的那样,当前workaround是要设置:

import torch
torch.set_default_tensor_type(torch.cuda.FloatTensor)

就在导入 torch 之后。

代码:

import torch
from torch import nn
torch.set_default_tensor_type(torch.cuda.FloatTensor) # <-- workaround

import syft as sy
hook = sy.TorchHook(torch)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = nn.Sequential(nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim = 1))

model = model.to(device)
print(model)

输出:

cuda
Sequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=128, bias=True)
(3): ReLU()
(4): Linear(in_features=128, out_features=64, bias=True)
(5): ReLU()
(6): Linear(in_features=64, out_features=10, bias=True)
(7): LogSoftmax()
)

关于python - 后端 CPU 的预期对象,但参数 #2 得到后端 CUDA 'source',我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56702413/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com