gpt4 book ai didi

python - 使用 numba 优化 Jaccard 距离性能

转载 作者:太空宇宙 更新时间:2023-11-03 15:55:24 24 4
gpt4 key购买 nike

我正在尝试使用 Numba 在 python 中实现最快的 jaccard 距离版本

@nb.jit()
def nbjaccard(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
return 1 - len(set1 & set2) / float(len(set1 | set2))

def jaccard(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
return 1 - len(set1 & set2) / float(len(set1 | set2))


%%timeit
nbjaccard("compare this string","compare a different string")

--12.4 毫秒

%%timeit 
jaccard("compare this string","compare a different string")

--3.87 毫秒

为什么 numba 版本需要更长的时间?有什么方法可以加速吗?

最佳答案

在我看来,允许对象模式 numba 函数是一个设计错误(或者如果 numba 意识到整个函数使用 python 对象则没有警告)——因为这些是通常比纯 Python 函数慢一点。

Numba 非常强大(类型分派(dispatch)和您可以在没有类型声明的情况下编写 python 代码 - 与 C 扩展或 Cython 相比 - 真的很棒)但只有当它支持操作时:

这意味着“nopython”模式不支持任何未在此处列出的操作。如果 numba 必须退回到 "object mode"然后小心:

object mode

A Numba compilation mode that generates code that handles all values as Python objects and uses the Python C API to perform all operations on those objects. Code compiled in object mode will often run no faster than Python interpreted code, unless the Numba compiler can take advantage of loop-jitting.

这正是您的情况:您纯粹在对象模式下操作:

>>> nbjaccard.inspect_types()

[...]
# --- LINE 3 ---
# seq1 = arg(0, name=seq1) :: pyobject
# seq2 = arg(1, name=seq2) :: pyobject
# $0.1 = global(set: <class 'set'>) :: pyobject
# $0.3 = call $0.1(seq1) :: pyobject
# $0.4 = global(set: <class 'set'>) :: pyobject
# $0.6 = call $0.4(seq2) :: pyobject
# set1 = $0.3 :: pyobject
# set2 = $0.6 :: pyobject

set1, set2 = set(seq1), set(seq2)

# --- LINE 4 ---
# $const0.7 = const(int, 1) :: pyobject
# $0.8 = global(len: <built-in function len>) :: pyobject
# $0.11 = set1 & set2 :: pyobject
# $0.12 = call $0.8($0.11) :: pyobject
# $0.13 = global(float: <class 'float'>) :: pyobject
# $0.14 = global(len: <built-in function len>) :: pyobject
# $0.17 = set1 | set2 :: pyobject
# $0.18 = call $0.14($0.17) :: pyobject
# $0.19 = call $0.13($0.18) :: pyobject
# $0.20 = $0.12 / $0.19 :: pyobject
# $0.21 = $const0.7 - $0.20 :: pyobject
# $0.22 = cast(value=$0.21) :: pyobject
# return $0.22

return 1 - len(set1 & set2) / float(len(set1 | set2))

如您所见,每个操作都对 Python 对象进行操作(如每行末尾的 ::pyobject 所示)。那是因为 numba 不支持 strset。所以这里绝对没有什么比这更快的了。除非您知道如何使用 numpy 数组或同类列表(数字类型)解决此问题。

在我的电脑上,时差要大得多(使用 numba 0.32.0),但个别计时要快得多 - 秒(10**-6 秒) 而不是 毫秒秒(10**-3 秒):

%timeit nbjaccard("compare this string","compare a different string")
10000 loops, best of 3: 84.4 µs per loop

%timeit jaccard("compare this string","compare a different string")
100000 loops, best of 3: 15.9 µs per loop

请注意 jit 默认为 lazy ,因此第一次调用应该在您为执行计时之前完成 - 因为它包括编译代码的时间。


不过,您可以进行一种优化:如果您知道两个集合的交集,则可以计算并集的长度(正如@Paul Hankin 在他的现已删除 回答中提到的):

len(union) = len(set1) + len(set2) - len(intersection)

这将导致以下(纯 python)代码:

def jaccard2(seq1, seq2):
set1, set2 = set(seq1), set(seq2)
num_intersection = len(set1 & set2)
return 1 - num_intersection / float(len(set1) + len(set2) - num_intersection)

%timeit jaccard2("compare this string","compare a different string")
100000 loops, best of 3: 13.7 µs per loop

不是更快 - 但更好。


如果您使用 ,还有一些改进空间:

%load_ext cython

%%cython
def cyjaccard(seq1, seq2):
cdef set set1 = set(seq1)
cdef set set2 = set()

cdef Py_ssize_t length_intersect = 0

for char in seq2:
if char not in set2:
if char in set1:
length_intersect += 1
set2.add(char)

return 1 - (length_intersect / float(len(set1) + len(set2) - length_intersect))

%timeit cyjaccard("compare this string","compare a different string")
100000 loops, best of 3: 7.97 µs per loop

这里的主要优点是只需一次迭代就可以创建 set2 并计算交集中的元素数量(根本不需要创建交集)!

关于python - 使用 numba 优化 Jaccard 距离性能,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43596535/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com