gpt4 book ai didi

python - 如何优化这段代码以进行 nn 预测?

转载 作者:行者123 更新时间:2023-11-30 22:49:45 25 4
gpt4 key购买 nike

如何优化这段代码?目前,它的运行速度因经过此循环的数据量而变慢。此代码运行 1-最近邻。它将根据p_data_set预测training_element的标签

#               [x] ,           [[x1],[x2],[x3]],    [l1, l2, l3]
def prediction(training_element, p_data_set, p_label_set):
temp = np.array([], dtype=float)
for p in p_data_set:
temp = np.append(temp, distance.euclidean(training_element, p))

minIndex = np.argmin(temp)
return p_label_set[minIndex]

最佳答案

使用 k-D tree用于快速最近邻查找,例如scipy.spatial.cKDTree :

from scipy.spatial import cKDTree

# I assume that p_data_set is (nsamples, ndims)
tree = cKDTree(p_data_set)

# training_elements is also assumed to be (nsamples, ndims)
dist, idx = tree.query(training_elements, k=1)

predicted_labels = p_label_set[idx]

关于python - 如何优化这段代码以进行 nn 预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39625552/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com