gpt4 book ai didi

python - 使用Python+NumPy+Numba时如何加速数组访问

转载 作者:行者123 更新时间:2023-12-01 04:38:22 25 4
gpt4 key购买 nike

我正在尝试使用 Numba,看看我可以多快地编写 Python+NumPy 代码。我的测试函数计算三维空间中 n 个点的成对欧几里得距离。我使用 Numba 获得了 2 个数量级的加速。如果我注释掉在数组中存储距离的行(即 distance[i, j] = ddistance[j, i] = d),我使用 Numba 获得 6 个数量级的加速。所以基本上,计算速度快如闪电,但访问保存结果的数组却很慢。有没有办法加速数组访问?

NumPy 和 Numba 函数

import numpy as np
from numba import jit, float64, void

def pairwise_distance_numpy(distance, point):
numPoints = point.shape[0]
for i in range(numPoints):
for j in range(0, i):

d = 0.0
for k in range(3):
tmp = point[i, k] - point[j, k]
d += tmp*tmp
d = d**0.5

distance[i, j] = d
distance[j, i] = d

pairwise_distance_numba = jit(void(float64[:,:], float64[:,:]), nopython=True)(pairwise_distance_numpy)

基准脚本

import numpy as np
from time import time
from pairwise_distance import pairwise_distance_numpy as pd_numpy
from pairwise_distance import pairwise_distance_numba as pd_numba

n = 1000
point = np.random.rand(n, 3)
distance = np.empty([n, n], dtype=np.float64)

pd_numpy(distance, point)
t = time()
pd_numpy(distance, point)
dt_numpy = time() - t
print('Numpy elapsed time: ', dt_numpy)

pd_numba(distance, point)
t = time()
pd_numba(distance, point)
dt_numba = time() - t
print('Numba Elapsed time: ', dt_numba)

print('Numba speedup: ', dt_numpy/dt_numba)

最佳答案

Numba 似乎只是优化了计算,因为您没有将结果存储在变量中。 (来自您的代码+您的评论证实了这一点)在大多数情况下,numpy 中的数组访问应该相当快!

关于python - 使用Python+NumPy+Numba时如何加速数组访问,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31314426/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com