python - Fast computation of the Hessian matrix of model parameters in PyTorch

Reposted. Author: 行者123. Last updated: 2023-12-05 05:29:40

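I want to compute the Hessian of a loss w.r.t. the model parameters in PyTorch, but torch.autograd.functional.hessian is not an option for me, because it recomputes the model output and the loss, both of which I already have from an earlier call. My current implementation is as follows: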

import torch
import time

# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
num_param = sum(p.numel() for p in model.parameters())

# Evaluate some loss on a random dataset
x = torch.rand((1000,1))
y = torch.rand((1000,1))
y_hat = model(x)
loss = ((y_hat - y)**2).mean()

''' Calculate Hessian '''
start = time.time()

# Allocate Hessian size
H = torch.zeros((num_param, num_param))

# Calculate Jacobian w.r.t. model parameters
J = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
J = torch.cat([e.flatten() for e in J]) # flatten

# Fill in Hessian
for i in range(num_param):
    # Differentiate the i-th gradient entry to get row i of the Hessian
    result = torch.autograd.grad(J[i], list(model.parameters()), retain_graph=True)
    H[i] = torch.cat([r.flatten() for r in result]) # flatten

print(time.time() - start)

Is there a way to do this faster? Perhaps without the for loop, since it calls autograd.grad once for every single model parameter.

Best answer

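One way to make it faster is to use functorch.hessian (based on this issue), but it has to recompute the loss every time the Hessian is computed (even though I already have access to the loss). Still, I'll post it for anyone who is interested. I still think it is too slow.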

import torch
from functorch import hessian
from torch.nn.utils import _stateless
import time

# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
num_param = sum(p.numel() for p in model.parameters())
names = list(n for n, _ in model.named_parameters())

# Create random dataset
x = torch.rand((1000,1))
y = torch.rand((1000,1))

# Define loss function
def loss(params):
    # Run the model with the given parameter tensors instead of its own
    y_hat = _stateless.functional_call(model, {n: p for n, p in zip(names, params)}, x)
    return ((y_hat - y)**2).mean()

# Calculate Hessian
hessian_func = hessian(loss)

start = time.time()

H = hessian_func(tuple(model.parameters()))
H = torch.cat([torch.cat([e.flatten() for e in Hpart]) for Hpart in H]) # flatten
H = H.reshape(num_param, num_param)

print(time.time() - start)
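
As a possible alternative that keeps the already-computed loss, the row-by-row loop from the question could be batched with the is_grads_batched option of torch.autograd.grad. The following is only a minimal sketch, not part of the original answer and not benchmarked here: it assumes PyTorch 1.11 or later (where is_grads_batched is available), and the names params, g, eye and rows are illustrative.

```python
import torch

torch.manual_seed(0)

# Same model and data as in the question
model = torch.nn.Sequential(torch.nn.Linear(1, 100), torch.nn.Tanh(), torch.nn.Linear(100, 1))
params = list(model.parameters())
num_param = sum(p.numel() for p in params)

x = torch.rand((1000, 1))
y = torch.rand((1000, 1))
y_hat = model(x)
loss = ((y_hat - y) ** 2).mean()  # computed once and reused below

# First-order gradient, keeping the graph so it can be differentiated again
grads = torch.autograd.grad(loss, params, create_graph=True)
g = torch.cat([e.flatten() for e in grads])  # shape: (num_param,)

# Each row of the identity matrix selects one gradient entry; with
# is_grads_batched=True the rows are treated as a batch of grad_outputs
# (assumption: PyTorch >= 1.11), replacing the explicit Python loop.
eye = torch.eye(num_param)
rows = torch.autograd.grad(g, params, grad_outputs=eye, is_grads_batched=True)

# Each entry of rows has shape (num_param, *p.shape); flatten and concatenate
H = torch.cat([r.reshape(num_param, -1) for r in rows], dim=1)
print(H.shape)
```

Note that the identity matrix takes as much memory as the Hessian itself, and the batching relies on vmap internally, so any operation without batching support may fall back to a slower path. Separately, on PyTorch 2.x the functorch transforms used above are also available as torch.func.hessian and torch.func.functional_call.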

Regarding "python - Fast computation of the Hessian matrix of model parameters in PyTorch", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/74900770/
