gpt4 book ai didi

python - 使用 for 循环时如何减少 JAX 编译时间?

转载 作者:行者123 更新时间:2023-12-05 03:38:07 25 4
gpt4 key购买 nike

这是一个基本示例。

@jax.jit
def block(arg1, arg2):
for x1 in range(cons1):
for x2 in range(cons2):
for x3 in range(cons3):
--do something--
return result

当 cons 较小时,编译时间约为一分钟。缺点越大,编译时间就越长——几十分钟。我需要更高的缺点。可以做什么?根据我正在阅读的内容,循环是原因。它们在编译时展开。有什么解决方法吗?还有jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。

我对这一切都很陌生。因此,感谢所有帮助。如果您能提供一些如何使用 jax 循环的示例,我们将不胜感激。

另外,什么是好的编译时间?可以在几分钟内完成吗?在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。

运行时的任何 yield 都被编译时间所掩盖。

最佳答案

JAX 的 JIT 编译器扁平化了所有 Python 循环。要明白我的意思,请看一下通过 jax.make_jaxpr 运行的这个简单函数,这是一种检查 JAX 的跟踪器如何解释 python 代码的方法(有关更多信息,请参见 Understanding Jaxprs):

import jax

def f(x):
for i in range(5):
x += i
return x

print(jax.make_jaxpr(f)(0))
# { lambda ; a.
# let b = add a 0
# c = add b 1
# d = add c 2
# e = add d 3
# f = add e 4
# in (f,) }

请注意循环是扁平化的:每一步都成为发送到 XLA 编译器的显式操作。 XLA 编译时间随着函数中操作次数的增加而增加,因此三重嵌套的 for 循环会导致较长的编译时间是有道理的。

那么,如何解决这个问题呢?好吧,很遗憾,答案取决于您的 --do something-- 正在做什么,所以我无法猜测。

通常,最好的选择是使用向量化数组操作,而不是循环这些向量中的值;例如,这是一个非常慢的添加两个向量的方法:

import jax.numpy as jnp

def f_slow(x, y):
z = []
for xi, yi in zip(xi, yi):
z.append(xi + yi)
return jnp.array(z)

这里有一个更快的方法来做同样的事情:

def f_fast(x, y):
return x + y

如果您的操作不适合矢量化,另一种选择是使用 lax control flow 运算符代替 for 循环:这会将循环向下插入 XLA。这在 CPU 上具有相当好的性能,但与等效的矢量化数组操作相比,在加速器上速度较慢。

有关 JAX 和 Python 控制流语句(例如 forifwhile 等)的更多讨论,请参阅 🔪 JAX - The Sharp Bits 🔪: Control Flow

关于python - 使用 for 循环时如何减少 JAX 编译时间?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69070804/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com