- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
jax.numpy.split
可用于将数组分割成等长的段,余数在最后一个元素中。例如将 5000 个元素的数组拆分为 10 个片段:
array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)
segments = jnp.split(array, split_indices)
这需要大约 10 秒才能在 Google Colab 和我的本地计算机上执行。 对于一个小阵列上的如此简单的任务来说,这似乎是不合理的。我做错了什么让这变慢了吗?
对 .split
的后续调用非常快,提供了相同形状和相同拆分索引的数组。例如以下循环的第一次迭代非常慢,但其他所有迭代都很快。 (11 秒对 40 毫秒)
from timeit import default_timer as timer
import jax.numpy as jnp
array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)
for k in range(5):
start = timer()
segments = jnp.split(array, split_indices)
end = timer()
print(f'call {k}: {end - start:0.2f} s')
输出:
call 0: 11.79 s
call 1: 0.04 s
call 2: 0.04 s
call 3: 0.05 s
call 4: 0.04 s
我假设后续调用会更快,因为 JAX 正在为每个参数组合缓存 split
的 jitted 版本。如果是这种情况,那么我假设 split
很慢(在第一次这样的调用中),因为编译开销。
这是真的吗?如果是,我应该如何应该在不影响性能的情况下拆分 JAX 数组?
最佳答案
这很慢,因为在 split()
的实现中存在权衡,而您的函数恰好在权衡的错误方面。
在 XLA 中有多种计算切片的方法,包括 XLA:Slice (即 lax.slice
),XLA:DynamicSlice (即 lax.dynamic_slice
)和 XLA:Gather (即 lax.gather
)。
这些之间的主要区别在于开始和结束索引是静态的还是动态的。静态索引本质上意味着您要专门针对特定索引值进行计算:这会在第一次调用时产生一些小的编译开销,但后续调用可能会非常快。另一方面,动态索引不包括这种专门化,因此编译开销较小,但每次执行所需的时间稍长。你或许能猜到这是怎么回事……
jnp.split
目前是根据 lax.slice
( see code ) 实现的,这意味着它使用静态索引。这意味着第一次使用 jnp.split
将产生与输出数量成正比的编译成本,但重复调用将执行得非常快。这似乎是 split
常见用途的最佳方法,其中会生成少量数组。
在您的情况下,您正在生成数百个数组,因此编译成本远远高于执行。
为了说明这一点,以下是基于 gather
、slice
和 dynamic_slice
的三种相同数组拆分方法的一些时间安排。如果您的程序受益于不同的实现,您可能希望直接使用其中之一,而不是使用 jnp.split
:
from timeit import default_timer as timer
from jax import lax
import jax.numpy as jnp
import jax
def f_slice(x, step=10):
return [lax.slice(x, (N,), (N + step,)) for N in range(0, x.shape[0], step)]
def f_dynamic_slice(x, step=10):
return [lax.dynamic_slice(x, (N,), (step,)) for N in range(0, x.shape[0], step)]
def f_gather(x, step=10):
step = jnp.asarray(step)
return [x[N: N + step] for N in range(0, x.shape[0], step)]
def time(f, x):
print(f.__name__)
for k in range(5):
start = timer()
segments = jax.block_until_ready(f(x))
end = timer()
print(f' call {k}: {end - start:0.2f} s')
x = jnp.ones(5000)
time(f_slice, x)
time(f_dynamic_slice, x)
time(f_gather, x)
这是 Colab CPU 运行时的输出:
f_slice
call 0: 7.78 s
call 1: 0.05 s
call 2: 0.04 s
call 3: 0.04 s
call 4: 0.04 s
f_dynamic_slice
call 0: 0.15 s
call 1: 0.12 s
call 2: 0.14 s
call 3: 0.13 s
call 4: 0.16 s
f_gather
call 0: 0.55 s
call 1: 0.54 s
call 2: 0.51 s
call 3: 0.58 s
call 4: 0.59 s
您可以在此处看到静态索引 (lax.slice
) 导致编译后执行最快。但是,为了生成许多切片,dynamic_slice
和 gather
避免了重复编译。可能我们应该根据 dynamic_slice
重新实现 jnp.split
,但这不会没有权衡:例如,它会导致(可能更常见?)很少拆分的情况,其中 lax.slice
在初始和后续运行中都会更快。此外,dynamic_slice
仅在每个切片大小相同时才避免重新编译,因此生成许多不同大小的切片会产生类似于 lax.slice
的大量编译开销。
这些权衡在 JAX 开发 channel 中得到了积极讨论;在 PR #12219 中可以找到一个与此非常相似的最新示例。 .如果您想就这个特定问题发表意见,我会邀请您提交 new jax issue主题。
最后一点:如果您真的只是对生成数组的等长连续切片感兴趣,那么调用 reshape
会更好:
out = x.reshape(len(x) // 10, 10)
结果现在是一个二维数组,其中每一行对应于上述函数的一个切片,这将远远优于任何生成数组切片列表的方法。
关于python - 为什么 JAX 的 `split()` 第一次调用这么慢?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/74199437/
在 R 中,您可以使用 strsplit在分隔符( split )上分割向量的函数如下: x <- "What is this? It's an onion. What! That's| Well
我的 .split(); 方法有问题。 我称这个函数为: get_content_ajax("html/settings.html", "#ajax", 1, "Settings page have
我是Elixir的新手。我正在尝试对字符串split的基本操作,如下所示 String.split("Awesome",""); 根据elixir document,它应该根据提供的模式split字符
当我使用 =arrayformula(split(input!G2:G, ",")) 时,为什么拆分公式没有扩展到整个列? 我只得到输入的结果!G2 单元格,而不是 G 列中的其余部分。其他公式如 =
我正在尝试制作一个名为 core-splitter 的元素,该元素在 1.0 中已弃用,因为它在我们的项目中起着关键作用。 如果您不知道 core-splitter 的作用,我可以提供一个简短的描述。
我很难尝试使用多个定界符将字符串拆分为列表。我可以像下面这样将它拆分两次: myString.split(':')[1].split('.') 然而,这看起来很不优雅。在我的脑海里,我想做这样的事情:
来自使用惊喜模块的推荐引擎的代码,我在任何地方都找不到答案。 最佳答案 根据您的目标,您可以使用 cross_validation方法,它将自动为您执行拆分。示例:cross_validate(alg
我正在制作一个有丝 split 模拟器,我希望它在细胞足够大并 split 时运行有丝 split 功能。当它分割时,我希望它能够将分割从初始 x 值(前一个单元格的 x)动画化为新的 x 值(右侧的
我有一个用于三个按钮的点击处理程序,在这个处理程序中我想提取所点击按钮的 ID。我有一行这样的代码: $('#switch button').click(function(){ var cla
我需要像这样分割一个字符串 var val = "$cs+55+mod($a)"; 放入数组 arr = val.split( /[+-/*()\s*]/ ); 问题是将分隔符保留为数组元素,如 ar
我在同一个 string 上使用 split() 和 split("") .但为什么 split("") 返回的元素数量少于 split()?我想知道在什么特定的输入情况下会发生这种情况。 最佳答案
我的代码中某处有错误,但看不到我做错了什么。 我拥有的是 facebook 用户 ID 的隐藏输入,它是通过 jQuery UI 自动完成填充的: 然后,我有一个 jQuery 函数,当单击链接将其
我正在寻找一个程序来读取字符串/文件并显示其中的前三个单词。 所以我尝试了: letter= "a,b,c" print(letter.split(',')[0]) 这对获取一个单词有效,但执行 [0
我有一个存储邮件的表 Mails(谁会想到... ;))。 通过 tinyint MailStatus,我决定这是 SentMail、Draft 还是 ReceivedMail。 现在我想知道 Tab
在我的优化探索中,我发现内置的 split() 方法比等效的 re.split() 方法快大约 40%。 虚拟基准(易于复制粘贴): import re, time, random def rando
我对split有一个奇怪的问题,因为默认情况下它不会将split放入默认数组中。 以下是一些玩具代码。 #!/usr/bin/perl $A="A:B:C:D"; split (":",$A); pr
我目前正在学习 JCL,并且正在使用 SORT 程序。作为练习,我想将一些输入记录拆分为属于同一 PDS 的多个成员。这是我的 JCL 代码: //FAILJ JOB //STEP1 EX
在苦苦挣扎了半小时之后,我在使用空格分割字符串时遇到了这种差异,具体取决于您使用的语法。 简单字符串: $line = "1: 2: 3: 4: 5: " 拆分示例1 -从1开始注意带有 token
我有一个像这样的字符串: 'Agendas / Schedules meetings and speakers 4 F 1928-1209 Box 2' 我正在尝试将其
我试图了解 r-tree 的工作原理,发现有两种类型的拆分:二次拆分和线性拆分。 线性和二次实际上有什么区别?在哪种情况下,一个会比另一个更受欢迎? 最佳答案 原始 R-Tree 论文在 3.5.2
我是一名优秀的程序员,十分优秀!