gpt4 book ai didi

python - JAX/JIT 与标准 Numpy 性能 : where I am wrong?

转载 作者:行者123 更新时间:2023-12-02 18:33:27 26 4
gpt4 key购买 nike

这是一个使用 Simpson 集成代码的简单练习,我已经编写了该代码来接受多个函数以在一组边界上进行集成

import numpy as np
def simps(f, a, b, N):
#N should be even
dx = (b - a) / N
x = np.linspace(a, b, N + 1)
y = f(x)
w = np.ones_like(y)
w[2:-1:2] = 2.
w[1::2] = 4.
S = dx / 3 * np.einsum("i...,i...",w,y)
return S

def funcN(x):
return np.stack([x**(i/10) * np.exp(-x) for i in range(200)],axis=1)

a = np.arange(0,10,0.1)
b = a+0.05

我在一个 CPU 设备上,然后我得到一个 200 x 100 数字数组,对应于Int(f_i, a_j,b_j) i:0-199 和 j:0-99

%timeit simps(funcN,a,b, 512)

每次循环 1.13 s ± 27.4 ms(7 次运行的平均值 ± 标准差,每次 1 次循环)

现在考虑以下 JAX/JIT 版本

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial
from jax.config import config
config.update("jax_enable_x64", True) #numpy by default is in double precision

@partial(jit, static_argnums=(0,3))
def jax_simps(f, a,b, N):
dx = (b - a) / N
x = jnp.linspace(a, b, N + 1)
y = f(x)
w = jnp.ones_like(y)
w = w.at[2:-1:2].set(2.)
w = w.at[1::2].set(4.)
S = dx / 3. * jnp.einsum('i...,i...',w,y)
return S

@jit
def jax_funcN(x):
return jnp.stack([x**(i/10) * jnp.exp(-x) for i in range(200)],axis=1)

ja = jnp.arange(0,10,0.1)
jb = ja+0.05

#warm up
jax_simps(jax_funcN,ja,jb, 512).block_until_ready()

%timeit jax_simps(jax_funcN,ja,jb, 512).block_until_ready()

我已经验证这两个代码(纯 Numpy 和 JAX/JIT)给出了相同的结果因为最大相对误差约为 8. 10^-16。

现在,我得到了以下时间每个循环 933 ms ± 51.4 ms(7 次运行的平均值 ± 标准差,每次 1 个循环)

这非常接近纯粹的 Numpy。我是否偶然编写了一个非常高效的纯 Numpy 代码???或者我是否以错误的方式编码了 JAX/JIT?

(注意。使用 Google collab K80 GPU,JAX/JIT 的时间下降到每个循环 7.19 毫秒,将纯 Numpy 保持在 1 秒/循环的水平)

最佳答案

从你的数字来看,JAX JIT 在 CPU 上的速度比 NumPy 提高了 20%。对于 CPU 执行,NumPy 已经相当理想:抛开 autodiff 之类的东西,对于类似 NumPy 的操作的短序列,JAX 在 CPU 上的主要优势是 XLA 能够融合操作以避免为中间结果分配临时数组,并且对于这个相对较短的序列操作顺序,看起来只能带来大约 20% 的改进。

现在,JAX 还有其他优势,包括自动差异、批处理以及(正如您提到的)无需更改代码即可降低至加速器的能力。但对于在 CPU 上执行短序列向量化操作,通常无法比单独使用 NumPy 做得更好。

顺便说一句:通过用广播操作替换 stack,您可以将 NumPy 和 JAX 版本的速度提高 40-50%;例如:

def funcN(x):
x = x[:, None, :]
i = np.arange(200)[:, None]
return x**(i/10) * np.exp(-x)

关于python - JAX/JIT 与标准 Numpy 性能 : where I am wrong?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69129840/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com