
jax - Computing the Hessian matrix efficiently in JAX


In JAX's quickstart tutorial I found that the Hessian matrix of a differentiable function fun can be computed efficiently with the following lines of code:

from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

However, the Hessian can also be computed with any of the following:

def hessian(fun):
    return jit(jacrev(jacfwd(fun)))

def hessian(fun):
    return jit(jacfwd(jacfwd(fun)))

def hessian(fun):
    return jit(jacrev(jacrev(fun)))

Here is a minimal working example:

import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev


def comp_hessian():

    x = jnp.arange(1.0, 4.0)

    def sum_logistics(x):
        return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

    def hessian_1(fun):
        return jit(jacfwd(jacrev(fun)))

    def hessian_2(fun):
        return jit(jacrev(jacfwd(fun)))

    def hessian_3(fun):
        return jit(jacrev(jacrev(fun)))

    def hessian_4(fun):
        return jit(jacfwd(jacfwd(fun)))

    hessian_fn = hessian_1(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_2(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_3(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_4(sum_logistics)
    print(hessian_fn(x))


def main():
    comp_hessian()


if __name__ == "__main__":
    main()

I would like to know which approach is best to use, and when. I would also like to know whether grad() can be used to compute the Hessian, and how grad() differs from jacfwd and jacrev.

Best Answer

The answer to your question is in the JAX documentation; see, for example, this section: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev

Quoting its discussion of jacrev and jacfwd:

These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
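To make the tall/wide distinction concrete, here is a minimal sketch (the functions g_wide and g_tall and their shapes are made up for illustration): reverse mode costs roughly one pass per output, forward mode roughly one pass per input.

import jax.numpy as jnp
from jax import jacfwd, jacrev

# "Wide" map: many inputs, few outputs (1000 -> 3).
# Reverse mode needs one pass per output row, so jacrev fits here.
def g_wide(x):
    return jnp.array([jnp.sum(x ** 2), jnp.sum(jnp.sin(x)), jnp.sum(jnp.cos(x))])

# "Tall" map: few inputs, many outputs (3 -> 3000).
# Forward mode needs one pass per input column, so jacfwd fits here.
def g_tall(x):
    return jnp.outer(jnp.arange(1000.0), x).ravel()

print(jacrev(g_wide)(jnp.ones(1000)).shape)  # (3, 1000)
print(jacfwd(g_tall)(jnp.ones(3)).shape)     # (3000, 3)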

And a bit further down:

To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function 𝑓:ℝⁿ→ℝ), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since ∇𝑓:ℝⁿ→ℝⁿ), which is where forward-mode wins out.

Since your function looks like 𝑓:ℝⁿ→ℝ, jit(jacfwd(jacrev(fun))) is likely to be the most efficient approach.
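Note also that JAX exposes this composition directly as jax.hessian, which is implemented as forward-over-reverse (jacfwd(jacrev(fun))), so in practice you can simply write:

import jax
import jax.numpy as jnp

def sum_logistics(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(1.0, 4.0)

# Equivalent to jit(jacfwd(jacrev(sum_logistics))) from the question.
print(jax.jit(jax.hessian(sum_logistics))(x))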

As for why you cannot implement the Hessian with grad: grad is designed only for derivatives of functions with scalar outputs, while the Hessian matrix is, by definition, a composition of vector-valued Jacobians rather than scalar gradients.
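To make that concrete, here is a small sketch: grad can still supply the inner derivative, because for a scalar-output f the gradient ∇f is exactly what jacrev(f) computes; what fails is composing grad with itself, since grad(f) has a vector output.

import jax.numpy as jnp
from jax import grad, jacfwd, jit

def sum_logistics(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(1.0, 4.0)

# Works: grad(f) maps R^n -> R^n, and jacfwd can differentiate
# that vector-valued map, giving the (n, n) Hessian.
print(jit(jacfwd(grad(sum_logistics)))(x))

# Fails: grad requires a scalar-output function, but grad(f)
# returns a vector, so this raises a TypeError:
# grad(grad(sum_logistics))(x)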

This answer is based on a similar question found on Stack Overflow: https://stackoverflow.com/questions/70572362/
