gpt4 book ai didi

python - 在 jitted 函数中两次反转 numpy 数组的 View 使函数运行得更快

转载 作者:太空狗 更新时间:2023-10-30 01:24:29 25 4
gpt4 key购买 nike

所以我正在测试同一功能的两个版本的速度;一种是两次反转 numpy 数组的 View ,另一种是没有。代码如下:

import numpy as np
from numba import njit

@njit
def min_getter(arr):

if len(arr) > 1:
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min

for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result

else:
return arr

@njit
def min_getter_rev1(arr1):

if len(arr1) > 1:
arr = arr1[::-1][::-1]
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min

for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result

else:
return arr1
size = 500000
x = np.arange(size)
y = np.hstack((x[::-1], x))

y_min = min_getter(y)
yrev_min = min_getter_rev1(y)

令人惊讶的是,带有额外操作的那个在多个场合运行得稍微快一些。我用了%timeit两个功能大约 10 次;尝试了不同大小的数组,差异很明显(至少在我的电脑上是这样)。 min_getter 的运行时间在附近:

2.35 ms ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) (有时是2.33,有时是2.37,但从不低于2.30)

min_getter_rev1 的运行时间在附近:

2.22 ms ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) (有时是2.25,有时是2.23,但很少超过2.30)


关于为什么以及如何发生的任何想法?速度差异大约增加了 4-6%,这在某些应用程序中可能是一个大问题。加速的底层机制可能有助于加速一些 jitted 代码


注1:我试过size=5000000,每个函数都测试了5-10次,差异更加明显。较快的运行在 23.2 ms ± 51.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)较慢的是 24.4 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

注2:numpy的版本和 numba在测试期间是 1.16.50.45.1 ; python 版本是 3.7.4 ; IPython版本是 7.8.0 ;使用的 Python IDE 是 spyder .不同版本测试结果可能不同。

最佳答案

TL;DR:第二个代码更快可能只是一个幸运的巧合。


检查生成的类型揭示了一个重要的区别:

  • 在第一个示例中,您的 arr 被键入为 array(int32, 1d, C) 一个 C 连续数组。
min_getter.inspect_types()

min_getter (array(int32, 1d, C),) <--- THIS IS THE IMPORTANT LINE
--------------------------------------------------------------------------------
# File: <>
# --- LINE 4 ---
# label 0

@njit

# --- LINE 5 ---

def min_getter(arr):

[...]
  • 在第二个示例中,arr 被键入为 array(int32, 1d, A),这是一个不知道是否连续的数组。这是因为 [::-1] 返回一个没有连续性信息的数组,一旦丢失,第二个 [::-1] 就无法恢复。<
>>> min_getter_rev1.inspect_types()

[...]

# --- LINE 18 ---
# arr1 = arg(0, name=arr1) :: array(int32, 1d, C)
# $const0.2 = const(NoneType, None) :: none
# $const0.3 = const(NoneType, None) :: none
# $const0.4 = const(int, -1) :: Literal[int](-1)
# $0.5 = global(slice: <class 'slice'>) :: Function(<class 'slice'>)
# $0.6 = call $0.5($const0.2, $const0.3, $const0.4, func=$0.5, args=(Var($const0.2, <> (18)), Var($const0.3, <> (18)), Var($const0.4, <> (18))), kws=(), vararg=None) :: (none, none, int64) -> slice<a:b:c>
# del $const0.4
# del $const0.3
# del $const0.2
# del $0.5
# $0.7 = static_getitem(value=arr1, index=slice(None, None, -1), index_var=$0.6) :: array(int32, 1d, A)
# del arr1
# del $0.6
# $const0.8 = const(NoneType, None) :: none
# $const0.9 = const(NoneType, None) :: none
# $const0.10 = const(int, -1) :: Literal[int](-1)
# $0.11 = global(slice: <class 'slice'>) :: Function(<class 'slice'>)
# $0.12 = call $0.11($const0.8, $const0.9, $const0.10, func=$0.11, args=(Var($const0.8, <> (18)), Var($const0.9, <> (18)), Var($const0.10, <> (18))), kws=(), vararg=None) :: (none, none, int64) -> slice<a:b:c>
# del $const0.9
# del $const0.8
# del $const0.10
# del $0.11
# $0.13 = static_getitem(value=$0.7, index=slice(None, None, -1), index_var=$0.12) :: array(int32, 1d, A)
# del $0.7
# del $0.12
# arr = $0.13 :: array(int32, 1d, A) <---- THIS IS THE IMPORTANT LINE
# del $0.13

arr = arr1[::-1][::-1]

[...]

(其余生成的代码几乎相同)

如果已知数组是连续的,索引和迭代应该会更快。但这不是我们在这种情况下观察到的情况 - 恰恰相反。

那么可能是什么原因呢?

Numba 本身使用 LLVM 来“编译” jitted 代码。所以有一个实际的编译器参与,编译器可以进行优化。尽管 inspect_types() 检查的代码几乎相同,但实际的 LLVM/ASM 代码与 inspect_llvm()inspect_asm() 大不相同。因此,编译器(或 numba)能够在第二种情况下进行某种优化,这在第一种情况下是不可能的。或者应用于第一种情况的某些优化实际上使代码变得更糟。

然而,这意味着我们在第二种情况下只是“走运”。这可能不是可以控制的,因为它取决于:

  • numba 根据您的来源创建的类型,
  • numba 内部使用的对这些类型进行操作的源代码
  • 从这些类型和 numba 源代码生成的 LLVM 和
  • 从该 LLVM 生成的 ASM。

有太多可以应用优化(或不应用优化)的移动部件。


有趣的事实:如果你扔掉外面的 ifs:

import numpy as np
from numba import njit

@njit
def min_getter(arr):
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min

for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result

@njit
def min_getter_rev1(arr1):
arr = arr1[::-1][::-1]
result = np.empty(len(arr), dtype = arr.dtype)
local_min = arr[0]
result[0] = local_min

for i in range(1,len(arr)):
if arr[i] < local_min:
local_min = arr[i]
result[i] = local_min
return result

size = 500000
x = np.arange(size)
y = np.hstack((x[::-1], x))

y_min = min_getter(y)
yrev_min = min_getter_rev1(y)

%timeit min_getter(y) # 2.29 ms ± 86.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit min_getter_rev1(y) # 2.37 ms ± 212 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

在那种情况下,没有 [::-1][::-1] 的速度更快。

因此,如果你想让它可靠地更快:移动 if len(arr) > 1 检查函数外部并且不要使用 [::-1][::-1] 因为在大多数情况下,这会使函数运行得更慢(并且可读性更差)!

关于python - 在 jitted 函数中两次反转 numpy 数组的 View 使函数运行得更快,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58039192/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com