gpt4 book ai didi

python - 如何使这个 KNN 代码在 google colab 或任何其他基于 ipython 的环境中更快?

转载 作者:太空宇宙 更新时间:2023-11-03 20:14:07 25 4
gpt4 key购买 nike

我正在使用谷歌合作实验室对 DonorsChoose 数据集进行 KNN 分类。当我对 avgw2v 和 tfidf 数据集应用 KNeighbors 分类器时,执行以下代码大约需要 4 小时。

我已经尝试在 Kaggle 笔记本上运行它,但问题仍然存在。

import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import roc_auc_score
train_auc_set3 = []
cv_auc_set3 = []
K = [51, 101]
for i in tqdm(K):
neigh = KNeighborsClassifier(n_neighbors=i, n_jobs=-1)
neigh.fit(X_tr_set3, y_train)

y_train_set3_pred = batch_predict(neigh, X_tr_set3)
y_cv_set3_pred = batch_predict(neigh, X_cr_set3)
train_auc_set3.append(roc_auc_score(y_train,y_train_set3_pred))
cv_auc_set3.append(roc_auc_score(y_cv, y_cv_set3_pred))

plt.plot(K, train_auc_set3, label='Train AUC')
plt.plot(K, cv_auc_set3, label='CV AUC')

plt.scatter(K, train_auc_set3, label='Train AUC points')
plt.scatter(K, cv_auc_set3, label='CV AUC points')

plt.legend()
plt.xlabel("K: hyperparameter")
plt.ylabel("AUC")
plt.title("ERROR PLOTS")
plt.grid()
plt.show()

最佳答案

这可能本质上很慢。我对这个数据集不是很熟悉,但在 Kaggle 上看了一下,它看起来包含超过 400 万个数据点。来自 KNN 上的 sklearn 页面:

For each iteration, time complexity is O(n_components x n_samples >x min(n_samples, n_features)).

还请记住,对于大型数据集,knn 必须测量给定数据点与训练集中所有数据点之间的距离才能进行预测,这在计算上是昂贵的。

对于非常大的数据集使用 k 上的大量数字可能会得到非常差的性能。我可能会做的是:

  1. 查看用单个 k 值拟合 knn 以及用单个 k 值对训练集进行预测需要多长时间。如果需要很长时间,那么我怀疑这就是你的问题。

不幸的是,有时对于非常大的数据集,我们在选择算法时会受到我们可能想要使用的算法的时间复杂度的限制。例如,核岭回归是一种很好的算法,但它不能很好地扩展到大型数据集,因为它的时间复杂度为 O(N^3)。

关于python - 如何使这个 KNN 代码在 google colab 或任何其他基于 ipython 的环境中更快?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58564316/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com