gpt4 book ai didi

python - 为什么 PyTorch nn.Module.cuda() 不移动模块张量而只移动参数和缓冲区到 GPU?

转载 作者:行者123 更新时间:2023-12-03 14:15:37 27 4
gpt4 key购买 nike

nn.Module.cuda()将所有模型参数和缓冲区移动到 GPU。

但为什么不是模型成员张量?

class ToyModule(torch.nn.Module):
def __init__(self) -> None:
super(ToyModule, self).__init__()
self.layer = torch.nn.Linear(2, 2)
self.expected_moved_cuda_tensor = torch.tensor([0, 2, 3])

def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.layer(input)

toy_module = ToyModule()
toy_module.cuda()
next(toy_module.layer.parameters()).device
>>> device(type='cuda', index=0)

对于模型成员张量,设备保持不变。
>>> toy_module.expected_moved_cuda_tensor.device
device(type='cpu')

最佳答案

如果您在模块内定义张量,则需要将其注册为参数或缓冲区,以便模块知道它。

参数是要训练的张量,将由 model.parameters() 返回.它们很容易注册,您需要做的就是将张量包装在 nn.Parameter 中类型,它将自动注册。请注意,只有浮点张量可以作为参数。

class ToyModule(torch.nn.Module):
def __init__(self) -> None:
super(ToyModule, self).__init__()
self.layer = torch.nn.Linear(2, 2)
# registering expected_moved_cuda_tensor as a trainable parameter
self.expected_moved_cuda_tensor = torch.nn.Parameter(torch.tensor([0., 2., 3.]))

def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.layer(input)

缓冲区是将在模块中注册的张量,因此像 .cuda() 这样的方法会影响它们,但它们不会被 model.parameters() 返回.缓冲区不限于特定的数据类型。
class ToyModule(torch.nn.Module):
def __init__(self) -> None:
super(ToyModule, self).__init__()
self.layer = torch.nn.Linear(2, 2)
# registering expected_moved_cuda_tensor as a buffer
# Note: this creates a new member variable named expected_moved_cuda_tensor
self.register_buffer('expected_moved_cuda_tensor', torch.tensor([0, 2, 3])))

def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.layer(input)

在上述两种情况下,以下代码的行为相同
>>> toy_module = ToyModule()
>>> toy_module.cuda()
>>> next(toy_module.layer.parameters()).device
device(type='cuda', index=0)
>>> toy_module.expected_moved_cuda_tensor.device
device(type='cuda', index=0)

关于python - 为什么 PyTorch nn.Module.cuda() 不移动模块张量而只移动参数和缓冲区到 GPU?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60908827/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com