gpt4 book ai didi

python - PyTorch torch.no_grad() 与 requires_grad=False

转载 作者:行者123 更新时间:2023-12-03 16:32:43 31 4
gpt4 key购买 nike

我正在关注 PyTorch tutorial它使用 Huggingface Transformers 库中的 BERT NLP 模型(特征提取器)。有两段我不明白的梯度更新相关代码。
(1) torch.no_grad()该教程有一个类,其中 forward()函数创建一个 torch.no_grad()阻止对 BERT 特征提取器的调用,如下所示:

bert = BertModel.from_pretrained('bert-base-uncased')

class BERTGRUSentiment(nn.Module):

def __init__(self, bert):
super().__init__()
self.bert = bert

def forward(self, text):
with torch.no_grad():
embedded = self.bert(text)[0]
(2) param.requires_grad = False在同一教程中还有另一部分卡住了 BERT 参数。
for name, param in model.named_parameters():                
if name.startswith('bert'):
param.requires_grad = False
我什么时候需要(1)和/或(2)?
  • 如果我想使用卡住的 BERT 进行训练,是否需要同时启用两者?
  • 如果我想训练让 BERT 更新,我是否需要同时禁用两者?

  • 另外,我运行了所有四种组合,发现:
       with torch.no_grad   requires_grad = False  Parameters  Ran
    ------------------ --------------------- ---------- ---
    a. Yes Yes 3M Successfully
    b. Yes No 112M Successfully
    c. No Yes 3M Successfully
    d. No No 112M CUDA out of memory
    有人可以解释一下发生了什么吗? 为什么我收到 CUDA out of memory对于(d)而不是(b)?两者都有 112M 的可学习参数。

    最佳答案

    这是一个较旧的讨论,多年来略有变化(主要是由于 with torch.no_grad() 作为一种模式的目的。可以在 on Stackoverflow already 中找到一个很好的答案,也可以回答您的问题。
    但是,由于原始问题大不相同,因此我不会将其标记为重复,特别是由于关于内存的第二部分。no_grad的初步解释给出 here :

    with torch.no_grad() is a context manager and is used to prevent calculating gradients [...].

    requires_grad另一方面使用

    to freeze part of your model and train the rest [...].


    再次来源 the SO post .
    本质上,与 requires_grad您只是在禁用部分网络,而 no_grad根本不会存储任何梯度,因为您可能将它用于推理而不是训练。
    为了分析您的参数组合的行为,让我们调查发生了什么:
  • a)b)根本不存储任何梯度,这意味着无论参数数量如何,您都可以使用更多的内存,因为您不会保留它们以用于潜在的向后传递。
  • c)必须为以后的反向传播存储前向传播,但是,只存储了有限数量的参数(300 万),这使得这仍然可以管理。
  • d)但是,需要存储所有 1.12 亿个参数的前向传递,这会导致内存不足。
  • 关于python - PyTorch torch.no_grad() 与 requires_grad=False,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63785319/

    31 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com