gpt4 book ai didi

python - 加速简单的距离计算

转载 作者:塔克拉玛干 更新时间:2023-11-03 03:37:38 24 4
gpt4 key购买 nike

我正在实现一个简单的代码来计算 list_A 中的点 (x_a, y_a) 和所有点 (x_b, y_b) in list_B 并返回找到的最小距离。对 list_A 中的所有点重复此操作。

我的代码的MWE:

# list_A points defined in array.
list_A = np.array([
[x_data_a, # x
y_data_a] # y
], dtype=float)

# list_B points defined in list.
list_B = [[x_data_b], [y_data_b]]

# Iterate through all data points in list_A
for ind, x_a in enumerate(list_A[0][0]):
y_a = list_A[0][1][ind]

# Iterate through all points in list_B.
dist_min = 1000.
for ind2, x_b in enumerate(list_B[0]):
y_b = list_B[1][ind2]
# Find distance between points.
dist = (x_a-x_b)**2 + (y_a-y_b)**2
if dist < dist_min:
# Update value of min distance.
dist_min = dist

print 'Min dist to (', x_a, y_a, '): ', dist_min

数据格式如下:

list_A = [[[1.2 2.3 1.5 2.3 5.8 4.6 9.1] [2.5 1.0 4.6 2.4 7.4 1.1 3.2]]]

list_B = [[1.4, 5.8, 7.9], [6.1, 1.2, 3.7]]

对于大列表/数组,这可能需要相当长的时间才能完成。这可以加快速度吗?

最佳答案

运行你的代码我得到以下信息:

Min dist to ( 1.2 2.5 ):  13.0
Min dist to ( 2.3 1.0 ): 12.29
Min dist to ( 1.5 4.6 ): 2.26
Min dist to ( 2.3 2.4 ): 13.69
Min dist to ( 5.8 7.4 ): 18.1
Min dist to ( 4.6 1.1 ): 1.45
Min dist to ( 9.1 3.2 ): 1.69

将您的数组转换为以下 Nx2 数组:

a
[[ 1.2 2.5]
[ 2.3 1. ]
[ 1.5 4.6]
[ 2.3 2.4]
[ 5.8 7.4]
[ 4.6 1.1]
[ 9.1 3.2]]

b
[[ 1.4 6.1]
[ 5.8 1.2]
[ 7.9 3.7]]

现在以下应该可以工作了:

import scipy.spatial.distance as spdist

dist_arr = spdist.cdist(a,b)

print dist_arr**2
[[ 13. 22.85 46.33]
[ 26.82 12.29 38.65]
[ 2.26 30.05 41.77]
[ 14.5 13.69 33.05]
[ 21.05 38.44 18.1 ]
[ 35.24 1.45 17.65]
[ 67.7 14.89 1.69]]

ind = np.argmin(dist_arr,axis=1)

print ind
[0 1 0 1 2 1 2]

print dist_arr[np.arange(ind.shape[0]),ind]**2
[ 13. 12.29 2.26 13.69 18.1 1.45 1.69]

如果 ab 是 2X5000,则需要 ~.3 秒,而原始代码需要 ~135 秒。加速 450 倍。

关于python - 加速简单的距离计算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/18947912/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com