gpt4 book ai didi

python - bool 数组上的 numpy.bitwise_and 的 numba 速度较慢

转载 作者:行者123 更新时间:2023-12-01 21:16:00 24 4
gpt4 key购买 nike

我正在这段代码片段中尝试 numba

from numba import jit
import numpy as np
from time import time
db = np.array(np.random.randint(2, size=(400e3, 4)), dtype=bool)
out = np.zeros((int(400e3), 1))

@jit()
def check_mask(db, out, mask=[1, 0, 1]):
for idx, line in enumerate(db):
target, vector = line[0], line[1:]
if (mask == np.bitwise_and(mask, vector)).all():
if target == 1:
out[idx] = 1
return out

st = time()
res = check_mask(db, out, [1, 0, 1])
print 'with jit: {:.4} sec'.format(time() - st)

使用 numba @jit() 装饰器,此代码运行速度变慢!

  • 无 jit:3.16 秒
  • 使用 jit:3.81 秒

只是为了帮助更好地理解这段代码的目的:

db = np.array([           # out value for mask = [1, 0, 1]
# target, vector #
[1, 1, 0, 1], # 1
[0, 1, 1, 1], # 0 (fit to mask but target == 0)
[0, 0, 1, 0], # 0
[1, 1, 0, 1], # 1
[0, 1, 1, 0], # 0
[1, 0, 0, 0], # 0
])

最佳答案

Numba 对于 jit 有两种编译模式:nopython 模式和对象模式。 Nopython 模式(默认)仅支持一组有限的 Python 和 Numpy 功能,请参阅 the docs for your version 。如果 jitted 函数包含不受支持的代码,Numba 必须退回到对象模式,这要慢得多。

我不确定 objcet 模式是否应该比纯 Python 提供加速,但无论如何你总是想使用 nopython 模式。为了确保使用 nopython 模式,请指定 nopython=True 并坚持使用非常基本的代码(经验法则:写出所有循环并仅使用标量和 Numpy 数组):

@jit(nopython=True)
def check_mask_2(db, out, mask=np.array([1, 0, 1])):
for idx in range(db.shape[0]):
if db[idx,0] != 1:
continue
check = 1
for j in range(db.shape[1]):
if mask[j] and not db[idx,j+1]:
check = 0
break
out[idx] = check
return out

显式写出内部循环还有一个优点,即我们可以在条件失败时立即跳出它。

时间安排:

%time _ = check_mask(db, out, np.array([1, 0, 1]))
# Wall time: 1.91 s
%time _ = check_mask_2(db, out, np.array([1, 0, 1]))
# Wall time: 310 ms # slow because of compilation
%time _ = check_mask_2(db, out, np.array([1, 0, 1]))
# Wall time: 3 ms

顺便说一句,该函数也可以轻松地使用 Numpy 进行矢量化,这提供了不错的速度:

def check_mask_vectorized(db, mask=[1, 0, 1]):
check = (db[:,1:] == mask).all(axis=1)
out = (db[:,0] == 1) & check
return out

%time _ = check_mask_vectorized(db, [1, 0, 1])
# Wall time: 14 ms

关于python - bool 数组上的 numpy.bitwise_and 的 numba 速度较慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34500913/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com