gpt4 book ai didi

python - 为什么 JAX 的 `split()` 第一次调用这么慢?

转载 作者:行者123 更新时间:2023-12-05 01:03:24 25 4
gpt4 key购买 nike

jax.numpy.split 可用于将数组分割成等长的段,余数在最后一个元素中。例如将 5000 个元素的数组拆分为 10 个片段:

array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)

segments = jnp.split(array, split_indices)

这需要大约 10 秒才能在 Google Colab 和我的本地计算机上执行。 对于一个小阵列上的如此简单的任务来说,这似乎是不合理的。我做错了什么让这变慢了吗?


更多细节(JIT 缓存,也许?)

.split 的后续调用非常快,提供了相同形状和相同拆分索引的数组。例如以下循环的第一次迭代非常慢,但其他所有迭代都很快。 (11 秒对 40 毫秒)

from timeit import default_timer as timer
import jax.numpy as jnp

array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)

for k in range(5):
start = timer()

segments = jnp.split(array, split_indices)

end = timer()
print(f'call {k}: {end - start:0.2f} s')

输出:

call 0: 11.79 s
call 1: 0.04 s
call 2: 0.04 s
call 3: 0.05 s
call 4: 0.04 s

我假设后续调用会更快,因为 JAX 正在为每个参数组合缓存 split 的 jitted 版本。如果是这种情况,那么我假设 split 很慢(在第一次这样的调用中),因为编译开销。

这是真的吗?如果是,我应该如何应该在不影响性能的情况下拆分 JAX 数组?

最佳答案

这很慢,因为在 split() 的实现中存在权衡,而您的函数恰好在权衡的错误方面。

在 XLA 中有多种计算切片的方法,包括 XLA:Slice (即 lax.slice),XLA:DynamicSlice (即 lax.dynamic_slice)和 XLA:Gather (即 lax.gather)。

这些之间的主要区别在于开始和结束索引是静态的还是动态的。静态索引本质上意味着您要专门针对特定索引值进行计算:这会在第一次调用时产生一些小的编译开销,但后续调用可能会非常快。另一方面,动态索引不包括这种专门化,因此编译开销较小,但每次执行所需的时间稍长。你或许能猜到这是怎么回事……

jnp.split 目前是根据 lax.slice ( see code ) 实现的,这意味着它使用静态索引。这意味着第一次使用 jnp.split 将产生与输出数量成正比的编译成本,但重复调用将执行得非常快。这似乎是 split 常见用途的最佳方法,其中会生成少量数组。

在您的情况下,您正在生成数百个数组,因此编译成本远远高于执行。

为了说明这一点,以下是基于 gatherslicedynamic_slice 的三种相同数组拆分方法的一些时间安排。如果您的程序受益于不同的实现,您可能希望直接使用其中之一,而不是使用 jnp.split:

from timeit import default_timer as timer
from jax import lax
import jax.numpy as jnp
import jax

def f_slice(x, step=10):
return [lax.slice(x, (N,), (N + step,)) for N in range(0, x.shape[0], step)]

def f_dynamic_slice(x, step=10):
return [lax.dynamic_slice(x, (N,), (step,)) for N in range(0, x.shape[0], step)]

def f_gather(x, step=10):
step = jnp.asarray(step)
return [x[N: N + step] for N in range(0, x.shape[0], step)]


def time(f, x):
print(f.__name__)
for k in range(5):
start = timer()
segments = jax.block_until_ready(f(x))
end = timer()
print(f' call {k}: {end - start:0.2f} s')

x = jnp.ones(5000)

time(f_slice, x)
time(f_dynamic_slice, x)
time(f_gather, x)

这是 Colab CPU 运行时的输出:

f_slice
call 0: 7.78 s
call 1: 0.05 s
call 2: 0.04 s
call 3: 0.04 s
call 4: 0.04 s
f_dynamic_slice
call 0: 0.15 s
call 1: 0.12 s
call 2: 0.14 s
call 3: 0.13 s
call 4: 0.16 s
f_gather
call 0: 0.55 s
call 1: 0.54 s
call 2: 0.51 s
call 3: 0.58 s
call 4: 0.59 s

您可以在此处看到静态索引 (lax.slice) 导致编译后执行最快。但是,为了生成许多切片,dynamic_slicegather 避免了重复编译。可能我们应该根据 dynamic_slice 重新实现 jnp.split,但这不会没有权衡:例如,它会导致(可能更常见?)很少拆分的情况,其中 lax.slice 在初始和后续运行中都会更快。此外,dynamic_slice 仅在每个切片大小相同时才避免重新编译,因此生成许多不同大小的切片会产生类似于 lax.slice 的大量编译开销。

这些权衡在 JAX 开发 channel 中得到了积极讨论;在 PR #12219 中可以找到一个与此非常相似的最新示例。 .如果您想就这个特定问题发表意见,我会邀请您提交 new jax issue主题。

最后一点:如果您真的只是对生成数组的等长连续切片感兴趣,那么调用 reshape 会更好:

out = x.reshape(len(x) // 10, 10)

结果现在是一个二维数组,其中每一行对应于上述函数的一个切片,这将远远优于任何生成数组切片列表的方法。

关于python - 为什么 JAX 的 `split()` 第一次调用这么慢?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/74199437/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com