gpt4 book ai didi

python - JAX:为什么不使用 @jit 产生 -inf 值但使用它却没有?

转载 作者:行者123 更新时间:2023-12-05 05:35:26 25 4
gpt4 key购买 nike

我正在摆弄 JAX,我通过使用 jit 装饰器遇到了两个不同的结果

import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats


def jitless_log_likelihood(x, mu, sigma):
return jnp.sum(jnp.log(jstats.multivariate_normal.pdf(x, mean=mu, cov=sigma)))

@jax.jit
def log_likelihood(x, mu, sigma):
return jitless_log_likelihood(x, mu, sigma)


key = jax.random.PRNGKey(0)

M = 10000

x = jax.random.normal(key, (10,M))
mu = jnp.array([0]*M)
sigma = jnp.identity(M)


print(jitless_log_likelihood(x, mu, sigma))
print(log_likelihood(x, mu, sigma))

在我的 CPU 中,产生以下代码

-141839.1 -inf

为什么会这样?

最佳答案

一般来说,JIT 编译会重新排列、组合或省略函数中的操作以提高效率,这有时会改变函数的数值结果。有关详细信息和更多解释,请参阅 FAQ: jit changes the exact numerics of outputs .在这种情况下,您计算指数数量 (multivariate_normal.pdf) 的 jnp.log 很可能是罪魁祸首。

这是查看相同行为的更简单方法:

from jax import jit
import jax.numpy as jnp

def f(x):
return jnp.log(jnp.exp(x))

x = -141839.1

print(f(x)) # -inf
print(jit(f)(x)) # -141839.1

在 JIT 编译的函数中,编译器注意到 explog 抵消了,因此省略了这些操作,避免了在计算它们时发生的下溢顺序。

您可以使用为此目的构建的 logpdf 函数,通过避免首先取指数对数来实现函数的更好的行为版本:

def jitless_log_likelihood(x, mu, sigma):
return jnp.sum(jstats.multivariate_normal.logpdf(x, mean=mu, cov=sigma))

关于python - JAX:为什么不使用 @jit 产生 -inf 值但使用它却没有?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73513353/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com