gpt4 book ai didi

python - 嵌套 Numpy 数组上的 Numba

转载 作者:太空宇宙 更新时间:2023-11-03 23:57:35 27 4
gpt4 key购买 nike

设置

我有以下两个矩阵计算的实现:

  1. 第一个实现使用形状为 (n, m) 的矩阵,并在 for 循环中重复计算 repetition 次:
import numpy as np
from numba import jit

@jit
def foo():
for i in range(1, n):
for j in range(1, m):

_deleteA = (
matrix[i, j] +
#some constants added here
)
_deleteB = (
matrix[i, j-1] +
#some constants added here
)
matrix[i, j] = min(_deleteA, _deleteB)

return matrix

repetition = 3
for x in range(repetition):
foo()


2. 第二种实现避免了额外的 for 循环,因此将 repetition = 3 包含到矩阵中,然后矩阵的形状为 shape (repetition, n, m):

@jit
def foo():
for i in range(1, n):
for j in range(1, m):

_deleteA = (
matrix[:, i, j] +
#some constants added here
)
_deleteB = (
matrix[:, i, j-1] +
#some constants added here
)
matrix[:, i, j] = np.amin(np.stack((_deleteA, _deleteB), axis=1), axis=1)

return matrix


问题

关于这两个实现,我在 iPython 中使用 %timeit 发现了两件关于它们性能的事情。

  1. 第一个实现从 @jit 中获利巨大,而第二个则根本没有(在我的测试用例中为 28 毫秒对 25 秒)。 有人能想象为什么 @jit 不再适用于形状为 (repetition, n, m) 的 numpy 数组吗?


编辑

我把之前的第二题移到了an extra post因为问多个问题被认为是糟糕的 SO 风格。

问题是:

  1. 当忽略 @jit 时,第一个实现仍然快很多(相同的测试用例:17 秒对 26 秒)。 为什么 numpy 在处理三维而不是二维时速度较慢?

最佳答案

我不确定你的设置是什么,但我稍微重写了你的例子:

import numpy as np
from numba import jit

#@jit(nopython=True)
def foo(matrix):
n, m = matrix.shape
for i in range(1, n):
for j in range(1, m):

_deleteA = (
matrix[i, j] #+
#some constants added here
)
_deleteB = (
matrix[i, j-1] #+
#some constants added here
)
matrix[i, j] = min(_deleteA, _deleteB)

return matrix

foo_jit = jit(nopython=True)(foo)

然后是计时:

m = np.random.normal(size=(100,50))

%timeit foo(m) # in a jupyter notebook
# 2.84 ms ± 54.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit foo_jit(m) # in a jupyter notebook
# 3.18 µs ± 38.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

所以这里 numba 比预期的要快得多。需要考虑的一件事是,全局 numpy 数组在 numba 中的行为并不像您预期​​的那样:

https://numba.pydata.org/numba-doc/dev/user/faq.html#numba-doesn-t-seem-to-care-when-i-modify-a-global-variable

通常最好像我在示例中那样传入数据。

在第二种情况下,您的问题是 numba 目前不支持 amin。见:

https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

如果将 nopython=True 传递给 jit,就可以看到这一点。因此,在当前版本的 numba(当前为 0.44 或更早版本)中,它将回退到 objectmode,这通常并不比不使用 numba 快,而且有时会更慢,因为存在一些调用开销。

关于python - 嵌套 Numpy 数组上的 Numba,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56992398/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com