gpt4 book ai didi

python - 使用 NumPY 改进 KNN 线性搜索效率

转载 作者:行者123 更新时间:2023-11-30 09:29:25 24 4
gpt4 key购买 nike

我正在尝试计算测试集中每个点与训练集中每个点的距离:

这就是我的循环现在的样子:

 for x in testingSet
for y in trainingSet
print numpy.linalg.norm(x-y)

其中testingSet和trainingSet是numpy数组,其中两个集合的每一行保存一个示例的特征数据。

但是,它的运行速度非常慢,需要 10 多分钟,因为我的数据集较大(测试集为 3000,训练集约为 10,000)。这与我的方法有关还是我错误地使用了 numPY?

最佳答案

这是因为你天真地迭代了数据,而 Python 中的循环速度很慢。相反,使用 sklearn pairwise distance functions ,或者更好 - 使用 sklearn efficient nearest neighbour搜索(如 BallTree 或 KDTree)。如果你不想用sklearn,还有一个module in scipy 。最后你可以用“矩阵技巧”来计算这个,因为

|| x - y ||^2 = <x-y, x-y> = <x,x> + <y,y> - 2<x,y>

你可以这样做(假设你的数据是以矩阵形式给出的 X 和 Y):

X2 = (X**2).sum(axis=1).reshape((-1, 1))
Y2 = (Y**2).sum(axis=1).reshape((1, -1))
distances = np.sqrt(X2 + Y2 - 2*X.dot(Y.T))

关于python - 使用 NumPY 改进 KNN 线性搜索效率,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39457604/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com