python - Faster kNN algorithm in Python

Reposted. Author: 太空宇宙. Updated: 2023-11-03 13:58:55
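I want to write my own kNN algorithm from scratch because I need to weight the features. The problem is that my program is still very slow, even though I removed the for loops and used built-in NumPy functionality.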

Can anyone suggest a way to speed it up? I don't use np.sqrt for the L2 distance because it isn't necessary and actually slows things down.

```python
import numpy as np


class GlobalWeightedKNN:
    """
    A k-NN classifier with feature weights

    Returns: predictions of k-NN.
    """

    def __init__(self):
        self.X_train = None
        self.y_train = None
        self.k = None
        self.weights = None
        self.predictions = list()

    def fit(self, X_train, y_train, k, weights):
        self.X_train = X_train
        self.y_train = y_train
        self.k = k
        self.weights = weights

    def predict(self, testing_data):
        """
        Takes a 2d array of query cases.

        Returns a list of predictions for k-NN classifier
        """
        # One prediction per query case; rebuilt on every call so that
        # repeated calls to predict() do not accumulate old results
        self.predictions = [self.__helper(qc) for qc in testing_data]
        return self.predictions

    def __helper(self, qc):
        # One Python-level distance call per training row
        neighbours = np.fromiter((self.__weighted_euclidean(qc, x) for x in self.X_train), float)
        neighbours = np.array([neighbours]).T
        indexes = np.array([range(len(self.X_train))]).T
        neighbours = np.append(indexes, neighbours, axis=1)

        # Sort by second column - distances
        neighbours = neighbours[neighbours[:, 1].argsort()]
        k_cases = neighbours[:self.k]
        indexes = [x[0] for x in k_cases]

        y_answers = [self.y_train[int(x)] for x in indexes]
        answer = max(set(y_answers), key=y_answers.count)  # get most common value
        return answer

    def __weighted_euclidean(self, qc, other):
        """
        Custom weighted euclidean distance (squared, i.e. without np.sqrt)

        returns: floating point number
        """
        return np.sum(((qc - other) ** 2) * self.weights)
```
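Skipping np.sqrt, as the question does, is safe for k-NN because the square root is strictly increasing on non-negative numbers: ranking by squared distance selects exactly the same neighbours as ranking by the true distance. A small check of that claim, using made-up random data and hypothetical weights (none of these values come from the question):

```python
import numpy as np

rng = np.random.default_rng(0)
X_train = rng.random((100, 3))       # hypothetical training data
query = rng.random(3)                # one hypothetical query case
weights = np.array([1.0, 0.5, 2.0])  # hypothetical feature weights

# Squared weighted distance, as computed by __weighted_euclidean in the question
sq_dist = np.sum(((query - X_train) ** 2) * weights, axis=1)
# The "true" weighted Euclidean distance
dist = np.sqrt(sq_dist)

k = 5
# sqrt preserves ordering, so both rankings pick the same k neighbours
nearest_sq = np.argsort(sq_dist, kind="stable")[:k]
nearest = np.argsort(dist, kind="stable")[:k]
same_neighbours = np.array_equal(nearest_sq, nearest)
```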

Best answer

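Scikit-learn uses a KD tree or a ball tree to compute nearest neighbours in O[N log(N)] time. Your algorithm is a direct (brute-force) approach that requires O[N^2] time, and it also uses nested for loops inside Python generator expressions, which adds significant computational overhead compared with optimized code.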
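To illustrate the overhead point separately from the complexity point, here is a minimal sketch (not part of the original answer) that is still brute force, so still O[N^2] distance computations, but moves every loop into NumPy. It expands the weighted squared distance as sum(w*q^2) - 2*sum(w*q*x) + sum(w*x^2) so that all query/training pairs come out of one matrix product, then uses np.argpartition to find the k smallest distances per row without a full sort. The function name `weighted_knn_predict` is hypothetical; it assumes non-negative weights and 1 <= k <= number of training rows.

```python
import numpy as np


def weighted_knn_predict(X_train, y_train, X_test, k, weights):
    """Hypothetical vectorized brute-force weighted k-NN (majority vote)."""
    X_train = np.asarray(X_train, dtype=float)
    X_test = np.asarray(X_test, dtype=float)
    y_train = np.asarray(y_train)
    w = np.asarray(weights, dtype=float)

    # Squared weighted distances for every (query, training row) pair:
    # sum_j w_j (q_j - x_j)^2 = sum_j w_j q_j^2 - 2 sum_j w_j q_j x_j + sum_j w_j x_j^2
    q_term = (X_test ** 2) @ w          # shape (n_test,)
    x_term = (X_train ** 2) @ w         # shape (n_train,)
    cross = (X_test * w) @ X_train.T    # shape (n_test, n_train)
    d2 = q_term[:, None] - 2.0 * cross + x_term[None, :]
    # (tiny negative values from rounding do not affect the ranking)

    # Indices of the k smallest distances per row, in no particular order
    nearest = np.argpartition(d2, k - 1, axis=1)[:, :k]
    neighbour_labels = y_train[nearest]  # shape (n_test, k)

    # Majority vote per row: encode labels as 0..C-1 and count them
    classes, encoded = np.unique(neighbour_labels, return_inverse=True)
    encoded = encoded.reshape(neighbour_labels.shape)
    counts = np.zeros((len(X_test), len(classes)), dtype=int)
    np.add.at(counts, (np.arange(len(X_test))[:, None], encoded), 1)
    # Ties go to the class that sorts first
    return classes[counts.argmax(axis=1)]
```

The expanded form keeps memory at one (n_test, n_train) matrix instead of an (n_test, n_train, n_features) array of differences; for very large inputs the test set could be processed in chunks.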

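If you want a fast O[N log(N)] implementation of weighted k-neighbors classification, you can use sklearn.neighbors.KNeighborsClassifier with the weighted Minkowski metric, setting p=2 (for Euclidean distance) and w to the weights you want. For example: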

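```python
from sklearn.neighbors import KNeighborsClassifier

# 'wminkowski' computes sum(|w * (x - y)|**p) ** (1/p): w scales each
# difference *before* it is squared. To reproduce the question's
# sum(weights * (x - y)**2), pass w=np.sqrt(weights) instead of weights.
model = KNeighborsClassifier(metric='wminkowski', p=2,
                             metric_params=dict(w=weights))
model.fit(X_train, y_train)
y_predicted = model.predict(X_test)
```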
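Depending on the installed scikit-learn version, metric='wminkowski' may be deprecated or no longer available. A version-independent alternative, sketched here as an assumption rather than as the answer's method, relies on the identity sum(w * (q - x)**2) = sum((sqrt(w) * q - sqrt(w) * x)**2), valid for non-negative weights: multiply every feature by sqrt(weight) once, then use the default Euclidean metric, which can still use the KD-tree or ball-tree search. The helper names `fit_weighted_knn` and `predict_weighted_knn` are hypothetical.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def fit_weighted_knn(X_train, y_train, k, weights):
    """Hypothetical helper: weighted k-NN via feature rescaling."""
    # Assumes all weights are >= 0 so that the square root is real
    scale = np.sqrt(np.asarray(weights, dtype=float))
    # Default metric is Minkowski with p=2, i.e. plain Euclidean distance
    model = KNeighborsClassifier(n_neighbors=k)
    model.fit(np.asarray(X_train, dtype=float) * scale, y_train)
    return model, scale


def predict_weighted_knn(model, scale, X_test):
    """Apply the same rescaling to the queries before predicting."""
    return model.predict(np.asarray(X_test, dtype=float) * scale)
```

The query points must be rescaled with the same factors as the training data, which is why the scale is returned alongside the fitted model.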

For more on python - Faster kNN algorithm in Python, see the similar question on Stack Overflow: https://stackoverflow.com/questions/51688568/
