gpt4 book ai didi

Lora训练的参数和性能

转载 作者:撒哈拉 更新时间:2024-05-08 17:14:01 76 4
gpt4 key购买 nike

主要为了测试模型增加Lora模块后,参数量和训练速度的变化情况。
结论:正常情况下,增加Lora模块是会增加参数量的,因此前向传播和反向传播的时间也会增加。
但是,在大语言模型训练的情况下,因为基础模型本身参数量非常大,Lora模块增加的参数量相对非常小。并且,基础模型不参与梯度更新,可以做模型量化,实际上是能减少模型训练时间和显存使用量的。
以下是实验脚本和运行结果:
#部分参考https://zhuanlan.zhihu.com/p/666000885
import time import torch from torch import nn from peft import LoraConfig, get_peft_model, PeftModel from torchsummary import summary x_train = torch.randn((1000, 10)) y_train = torch.randn((1000, 1)) net = nn.Sequential( nn.Linear(10,20), nn.Sigmoid(), nn.Linear(20,30), nn.Sigmoid(), nn.Linear(30,1) ) summary(net, (1,10)) config = LoraConfig(target_modules=["0"], r=2) model = get_peft_model(net, config) criterion = torch.nn.MSELoss(reduction='mean') # 定义损失函数,采用均方误差 optimizer = torch.optim.Adam(model.parameters(), lr=0.3) # 定义优化器,采用Adam summary(model, (1,10)) # base 前向计算时间 start = time.time() for i in range(100000): y_pre = net(x_train) # 前向传播 print("base 前向计算时间: ", time.time() - start) # lora 前向计算时间 start = time.time() for i in range(100000): y_pre = model(x_train) # 前向传播 print("lora 前向计算时间", time.time() - start) # base 反向传播时间 start = time.time() for i in range(1000): y_pre = net(x_train) # 前向传播 loss = criterion(y_pre, y_train) # 计算损失 optimizer.zero_grad() # 梯度清零 loss.backward() # 反向传播 optimizer.step() # 使用优化器更新梯度 print("base loss after training: ", loss.item()) print("base 反向计算时间", time.time() - start) # lora 反向传播时间 start = time.time() for i in range(1000): y_pre = model(x_train) # 前向传播 loss = criterion(y_pre, y_train) # 计算损失 optimizer.zero_grad() # 梯度清零 loss.backward() # 反向传播 optimizer.step() # 使用优化器更新梯度 print("lora loss after training: ", loss.item()) print("lora 反向计算时间", time.time() - start)

  运行代码输出:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 20]             220
           Sigmoid-2                [-1, 1, 20]               0
            Linear-3                [-1, 1, 30]             630
           Sigmoid-4                [-1, 1, 30]               0
            Linear-5                 [-1, 1, 1]              31
================================================================
Total params: 881
Trainable params: 881
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                [-1, 1, 20]             220
          Identity-2                [-1, 1, 10]               0
            Linear-3                 [-1, 1, 2]              20
            Linear-4                [-1, 1, 20]              40
            Linear-5                [-1, 1, 20]             220
           Sigmoid-6                [-1, 1, 20]               0
            Linear-7                [-1, 1, 30]             630
           Sigmoid-8                [-1, 1, 30]               0
            Linear-9                 [-1, 1, 1]              31
================================================================
Total params: 1,161
Trainable params: 60
Non-trainable params: 1,101
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.01
----------------------------------------------------------------
base loss after training:  1.0724023580551147
base 反向计算时间 2.9570980072021484
lora loss after training:  1.0643658638000488
lora 反向计算时间 3.053032159805298

最后此篇关于Lora训练的参数和性能的文章就讲到这里了,如果你想了解更多关于Lora训练的参数和性能的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

76 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com