gpt4 book ai didi

神经网络冻结层,固定参数

转载 作者:知者 更新时间:2024-03-13 01:41:55 28 4
gpt4 key购买 nike

bn层,卷积层测试:

import torch
from torch import nn

def init_weights(m):
    if type(m) == torch.nn.Linear :
        m.weight.data=torch.ones_like(m.weight)
        m.bias.data = torch.ones_like(m.bias)
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(6)
    def forward(self,x):
        a1=self.conv1(x)
        a2=self.bn1(a1)
        return a2

if __name__ == '__main__':

    net=Net()
    net.apply(init_weights)    #为了固定住网络的初始参数
    print(net.conv1.weight.grad,net.bn1.weight.grad,sep="\n")
    print("-------------")
    net.conv1.requires_grad_(False)
    net=net.train()
    x=torch.rand([2,3,8,8],dtype=torch.float32)
    y=net(x)
    y.sum().backward()
    print(net

28 4 0
文章推荐: Qt—坐标系统
文章推荐: MySQL数据存储
文章推荐: MySQL 触发器
文章推荐: MongoDB——聚合管道之$unwind操作
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com