gpt4 book ai didi

python - 为什么 numpy 不会在非连续数组上短路?

转载 作者:太空狗 更新时间:2023-10-29 18:03:30 25 4
gpt4 key购买 nike

考虑以下简单测试:

import numpy as np
from timeit import timeit

a = np.random.randint(0,2,1000000,bool)

让我们找到第一个 True

的索引
timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332

这相当快,因为​​ numpy 短路。

它也适用于连续的切片,

timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635

但似乎不是在非连续的上。我主要对找到最后一个 True 感兴趣:

timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168

UPDATE: My assumption that the observed slowdown was due to not short circuiting is inaccurate (thanks @Victor Ruiz). Indeed, in the worst-case scenario of an all False array

b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441

we are still an order of magnitude faster than in the non-contiguous case. I'm ready to accept Victor's explanation that the actual culprit is a copy being made (timings of forcing a copy with .copy() are suggestive). Afterwards it doesn't really matter anymore whether short-circuiting happens or not.

但其他步长 != 1 会产生类似的行为。

timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296

问题:为什么 numpy 在最后两个例子中没有短路 UPDATE 没有复制

而且,更重要的是:是否有解决方法,即强制 numpy 短路 UPDATE 而不制作副本在非连续数组上?

最佳答案

问题与使用strides时数组的内存对齐有关。a[1:-1]a[::-1] 被认为在内存中对齐,但 a[::2]不要:

a = np.random.randint(0,2,1000000,bool)

print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False

这解释了为什么 np.argmaxa[::2] 上很慢(来自 ndarrays 上的文档):

Several algorithms in NumPy work on arbitrarily strided arrays. However, some algorithms require single-segment arrays. When an irregularly strided array is passed in to such algorithms, a copy is automatically made.

np.argmax(a[::2]) 正在制作数组的副本。因此,如果您执行 timeit(lambda: np.argmax(a[::2]), number=5000),您将计时数组 a

执行这个并比较这两个计时调用的结果:

print(timeit(lambda: np.argmax(a[::2]), number=5000))

b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))

编辑:深入研究numpy的C语言源代码,我发现了argmax函数的下划线实现,PyArray_ArgMax在某个时候调用 PyArray_ContiguousFromAny确保给定的输入数组在内存中对齐(C 风格)

然后,如果数组的 dtype 是 bool,它委托(delegate)给 BOOL_argmax功能。查看其代码,似乎始终应用了短路。

总结

  • 为了避免被np.argmax复制,确保输入数组在内存中是连续的
  • 当数据类型为 bool 值时,始终应用短路。

关于python - 为什么 numpy 不会在非连续数组上短路?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57346182/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com