gpt4 book ai didi

python - Numba jit nopython 模式 : tell numba the signature of an external arbitrary function

转载 作者:行者123 更新时间:2023-11-28 22:24:02 27 4
gpt4 key购买 nike

我需要为通用指标构建相异矩阵。由于我需要算法快速运行,所以我在 nopython 模式下使用了 numba 0.35。这是我的代码

import numpy as np
from numba import jit
from jellyfish import levenshtein_distance

def _dissimilarity_matrix(metric):
@jit(nopython=True)
def dm(data):
n = data.shape[0]
diss = np.zeros((n, n))
for i in range(n):
for j in range(i+1):
dist = metric(data[i], data[j])
diss[i, j] = dist
diss[j, i] = dist
return diss

return dm

@jit(nopython=True)
def euclidean_distance(vec1, vec2):
return np.sqrt(((vec1 - vec2)**2).sum())

test1 = np.random.randn(10, 2)
dissimilarity_matrix1 = _dissimilarity_matrix(euclidean_distance)
diss1 = dissimilarity_matrix1(test1)

test2 = np.array(["this", "is", "a", "test"])
dissimilarity_matrix2 = _dissimilarity_matrix(levenshtein_distance)
diss2 = dissimilarity_matrix2(test2)

但是输出是:

numba.errors.TypingError: Failed at nopython (nopython frontend)
Untyped global name 'metric': cannot determine Numba type of <class 'builtin_function_or_method'>
File "test.py", line 12

注意函数euclidean_distance是我定义的,有装饰器@jit(nopython=True),而函数levenshtein_distance来了来自外部模块(不是我写的)。有没有办法明确告诉 numba 传入函数的签名(即 _dissimilarity_matrix 中的 metric)?我确实需要函数 _dissimilarity_matrixnopython 模式下运行并接受任意函数作为输入。

最佳答案

metriceuclidean_distance 时,您的代码对我有用,因为该函数也是 nopython jitted numba 函数。但是,您不能传入任意函数。为了让某些东西在 nopython 模式下工作,numba 必须支持每个可调用函数(参见 http://numba.pydata.org/numba-doc/latest/reference/pysupported.htmlhttp://numba.pydata.org/numba-doc/latest/reference/numpysupported.html )或者用户定义为 numba nopython 功能。没有办法绕过这个限制。

关于python - Numba jit nopython 模式 : tell numba the signature of an external arbitrary function,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46696639/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com