gpt4 book ai didi

deep-learning - 优化器的 state_dict 中保存了什么? "state","param_groups"代表什么?

转载 作者:行者123 更新时间:2023-12-04 12:19:32 27 4
gpt4 key购买 nike

当我们使用 Adam 优化器时,如果我们想从预训练模型继续训练网络,我们不仅要加载“model.state_dict”,还要加载“optimizer.state_dict”。而且,如果我们修改了我们的网络结构,我们还应该修改保存的优化器的 state_dict 以使我们的加载成功。

但我不明白保存的“optimizer.state_dict”中的一些参数。像 optim_dict["state"] (dict_keys(['step', 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'])) 和 optim_dict['param_groups'][0]['params']。有很多这样的数字:

 b['optimizer_state_dict']['state'].keys()
Out[71]: dict_keys([140623218628000, 140623218628072, 140623218628216, 140623218628360, 140623218628720, 140623218628792, 140623218628936, 140623218629080, 140623218629656, 140623218629728, 140623218629872, 140623218630016, 140623218630376, 140623218630448, 140623218716744, 140623218716816, 140623218717392, 140623218717464, 140623218717608, 140623218717752, 140623218718112, 140623218718184, 140623218718328, 140623218718472, 140623218719048, 140623218719120, 140623218719264, 140623218719408, 140623218719768, 140623218719840, 140623218719984, 140623218720128, 140623218720704, 140623209943112, 140623209943256, 140623209943400, 140623209943760, 140623209943832, 140623209943976, 140623209944120, 140623209944696, 140623209944768, 140623209944912, 140623209945056, 140623209945416, 140623209945488, 140623209945632, 140623209945776, 140623209946352, 140623209946424, 140623209946568, 140623209946712, 140623209947072, 140623210041416, 140623210041560, 140623210041704, 140623244033768, 140623244033840, 140623244033696, 140623244033912, 140623244033984, 140623244070984, 140623244071056, 140623244071128, 140623429501576, 140623244071200, 140623244071272, 140623244071344, 140623244071416, 140623244071488, 140623244071560, 140623244071632, 140623244071848, 140623244071920, 140623244072064, 140623244072208, 140623244072424, 140623244072496, 140623244072640, 140623244072784, 140623244073216, 140623244073288, 140623244073432, 140623244073576, 140623244073792, 140623244073864, 140623244074008, 140623244074152, 140623244074584, 140623244074656, 140623244074800, 140623244074944, 140623218540760, 140623218540832, 140623218540976, 140623218541120, 140623218541552, 140623218541624, 140623218541768, 140623218541912, 140623218542128, 140623218542200, 140623218542344, 140623218542488, 140623218542920, 140623218542992, 140623218543136, 140623218543280, 140623218543496, 140623218543568, 140623218543712, 140623218543856, 140623218544288, 140623218544360, 140623218544504, 140623218626632, 140623218626992, 140623218627064, 140623218627208, 140623218627352, 140623218627784, 140623218629440, 140623218717176, 140623218718832, 140623218720488, 140623209944480, 140623209946136, 140623210043000])
In [44]: b['optimizer_state_dict']['state'][140623218628072].keys()
Out[44]: dict_keys(['step', 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq'])

In [45]: b['optimizer_state_dict']['state'][140623218628072]['exp_avg'].shape
Out[45]: torch.Size([480])

最佳答案

与模型的 state_dict 对比,保存可学习参数,优化器的 state_dict包含有关优化器状态(要优化的参数)以及使用的超参数的信息。

PyTorch 中的所有优化器都需要继承自基类 torch.optim.Optimizer .它需要两个条目:

  • params (iterable)torch.Tensor 的迭代s 或 dict s。指定应该优化哪些张量。
  • defaults (dict) :dict包含优化选项的默认值(当参数组未指定它们时使用)。

  • 除此之外,优化器还支持指定每个参数的选项。

    To do this, instead of passing an iterable of Tensors, pass in an iterable of dicts. Each of them will define a separate parameter group, and should contain a params key, containing a list of parameters belonging to it.



    考虑一个例子,
    optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
    ], lr=1e-2, momentum=0.9)

    在这里,我们提供了 a) params , b) 默认超参数: lr , momentum , 和 c) 一个参数组。在这种情况下, model.base的参数将使用默认学习率1e-2, model.classifier的参数将使用 1e-3 的学习率,所有参数将使用 0.9 的动量。
    step ( optimizer.step() ) 执行单个优化步骤(参数更新),这会更改优化器的状态。

    现在,来到优化器的 state_dict ,它将优化器的状态返回为 dict .它包含两个条目:
  • state - 一个 dict保持当前的优化状态。
  • param_groups - 一个 dict包含所有参数组(如上所述)


  • 一些超参数特定于使用的优化器或模型,例如(用于亚当)
  • exp_avg :梯度值的指数移动平均
  • exp_avg_sq :平方梯度值的指数移动平均
  • 关于deep-learning - 优化器的 state_dict 中保存了什么? "state","param_groups"代表什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62260985/

    27 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com