gpt4 book ai didi

pytorch - pytorch 中的 model.cuda()

转载 作者:行者123 更新时间:2023-12-02 10:44:28 25 4
gpt4 key购买 nike

如果我在 pytorch 中调用 model.cuda(),其中 model 是 nn.Module 的子类,并假设我有四个 GPU,它将如何利用四个 GPU,我如何知道正在使用哪些 GPU?

最佳答案

如果您在 model.cuda() 所有模型参数之后有一个从 nn.Module 派生的自定义模块,(model.parameters()) > 迭代器可以向您显示这些)将在您的 cuda 上结束。

要检查你的参数在哪里,只需在我的例子中打印它们(cuda:0):

class M(nn.Module):
'custom module'
def __init__(self):
super().__init__()
self.lin = nn.Linear(784, 10)

m = M()
m.cuda()
for _ in m.parameters():
print(_)

# Parameter containing:
# tensor([[-0.0201, 0.0282, -0.0258, ..., 0.0056, 0.0146, 0.0220],
# [ 0.0098, -0.0264, 0.0283, ..., 0.0286, -0.0052, 0.0007],
# [-0.0036, -0.0045, -0.0227, ..., -0.0048, -0.0003, -0.0330],
# ...,
# [ 0.0217, -0.0008, 0.0029, ..., -0.0213, 0.0005, 0.0050],
# [-0.0050, 0.0320, 0.0013, ..., -0.0057, -0.0213, 0.0045],
# [-0.0302, 0.0315, 0.0356, ..., 0.0259, 0.0166, -0.0114]],
# device='cuda:0', requires_grad=True)
# Parameter containing:
# tensor([-0.0027, -0.0353, -0.0349, -0.0236, -0.0230, 0.0176, -0.0156, 0.0037,
# 0.0222, -0.0332], device='cuda:0', requires_grad=True)

您还可以像这样指定设备:

m.cuda('cuda:0')

使用torch.cuda.device_count(),您可以检查您拥有的设备数量。

关于pytorch - pytorch 中的 model.cuda(),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56852347/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com