
performance - Slow prediction: Scikit Gaussian Process classification

Reposted. Author: 行者123. Updated: 2023-11-30 09:27:11

The following code trains a Gaussian process (GP) and uses it to classify images from the MNIST dataset.

import random

from sklearn.datasets import fetch_openml
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.metrics import accuracy_score

SAMPLE_SIZE = 2000


def build_and_optimize(hp_space):
    build_train(hp_space)


def build_train(hp_space):
    l_scale = hp_space['l_scale']
    bias = hp_space['bias']
    # optimizer=None keeps the kernel hyperparameters fixed at the values given here.
    gp_fix = GaussianProcessClassifier(kernel=bias * RBF(length_scale=l_scale), optimizer=None, multi_class='one_vs_rest')

    X_train, X_test, y_train, y_test = prepare_data()

    gp_fix.fit(X_train, y_train)

    print("Log Marginal Likelihood (initial): %.3f"
          % gp_fix.log_marginal_likelihood(gp_fix.kernel_.theta))

    # Predict only the first 500 test images.
    y_ = gp_fix.predict(X_test[0:500])
    print(y_)
    print(y_test[0:500])
    print("Accuracy: %.3f (initial)"
          % (accuracy_score(y_test[0:500], y_)))


def prepare_data():
    # The original used fetch_mldata('MNIST original'), which has since been
    # removed from scikit-learn; fetch_openml is substituted here.
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, data_home='./')

    X_data = mnist.data / 255.0
    Y = mnist.target.astype(int)

    # Random training sample; every remaining index goes to the test set.
    shuf = random.sample(range(len(X_data)), SAMPLE_SIZE)
    c_shuf = sorted(set(range(len(X_data))) - set(shuf))

    X_train, y_train = X_data[shuf], Y[shuf]
    X_test, y_test = X_data[c_shuf], Y[c_shuf]

    return X_train, X_test, y_train, y_test


if __name__ == "__main__":
    hp_space = {}
    hp_space['l_scale'] = 1.0
    hp_space['bias'] = 1.0

    build_train(hp_space)

Training the model takes quite some time, but prediction takes far longer. Any pointers as to what might be the cause?

Best answer

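You can think of Gaussian processes and support vector machines as somewhat similar models: both use the kernel trick to build the model. Like SVMs, GPs take O(n^3) time to train, where n is the number of data points in the training set. So you should naturally expect training to take a while, and the time to grow quickly as the dataset gets larger.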
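To make the cubic growth concrete, here is a minimal sketch that assumes the dominant training cost is a Cholesky factorization of the n x n kernel matrix. The helper names `rbf_gram` and `cholesky_with_op_count` are made up for this illustration and are not part of scikit-learn. It counts the multiply-adds in the innermost loop, so the count is exact rather than a timing estimate.

```python
import math


def rbf_gram(points, length_scale=1.0, jitter=1e-8):
    """Build an RBF kernel matrix for 1-D points (hypothetical helper)."""
    n = len(points)
    K = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            d = points[i] - points[j]
            K[i][j] = math.exp(-0.5 * d * d / length_scale ** 2)
        K[i][i] += jitter  # small diagonal term for numerical stability
    return K


def cholesky_with_op_count(K):
    """Return (L, ops): the lower Cholesky factor of K and the number of
    multiply-adds performed in the innermost loop."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    ops = 0
    for i in range(n):
        for j in range(i + 1):
            s = 0.0
            for k in range(j):
                s += L[i][k] * L[j][k]
                ops += 1
            if i == j:
                L[i][j] = math.sqrt(K[i][i] - s)
            else:
                L[i][j] = (K[i][j] - s) / L[j][j]
    return L, ops


if __name__ == "__main__":
    for n in (10, 20, 40):
        _, ops = cholesky_with_op_count(rbf_gram([float(i) for i in range(n)]))
        print(n, ops)  # the count is (n - 1) * n * (n + 1) / 6
```

Doubling n multiplies the operation count by roughly eight, which is why going from a few hundred to a few thousand training images makes such a visible difference.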

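Similarly, GP prediction takes O(n) time per prediction, comparable to nearest-neighbor search and SVMs. However, the GP solution is dense, meaning it uses every training point for prediction, whereas the SVM solution is sparse, so it gets to throw some training points away.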
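The dense-versus-sparse point can be sketched with a toy kernel predictor, assuming the prediction has the form f(x) = sum_i w_i * k(x_i, x). The function names and the weights below are invented for illustration and are not taken from a fitted model. A GP-like predictor has a non-zero weight on every training point, while an SVM-like predictor only evaluates the kernel for its support vectors.

```python
import math


def rbf(a, b, length_scale=1.0):
    """RBF kernel between two points given as tuples."""
    sq = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-0.5 * sq / length_scale ** 2)


def kernel_score(x_new, train_points, weights):
    """Return (score, evals) for f(x_new) = sum_i w_i * k(x_i, x_new).

    Training points with zero weight are skipped, so evals is the number
    of kernel evaluations actually needed for this one prediction.
    """
    score, evals = 0.0, 0
    for x_i, w_i in zip(train_points, weights):
        if w_i == 0.0:
            continue
        score += w_i * rbf(x_i, x_new)
        evals += 1
    return score, evals


if __name__ == "__main__":
    train = [(float(i),) for i in range(8)]
    # Hypothetical weights: dense (GP-like) versus sparse (SVM-like).
    dense_w = [0.3, -0.2, 0.1, 0.4, -0.1, 0.2, -0.3, 0.1]
    sparse_w = [0.0, 0.0, 0.5, 0.0, 0.0, -0.5, 0.0, 0.0]
    print(kernel_score((2.5,), train, dense_w)[1])   # kernel evaluations, dense
    print(kernel_score((2.5,), train, sparse_w)[1])  # kernel evaluations, sparse
```

In the question's setup each kernel evaluation also touches all 784 pixel values, and the one-vs-rest scheme repeats the work once per class, so the per-prediction cost adds up quickly over 500 test images.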

Regarding performance - Slow prediction: Scikit Gaussian Process classification, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/44958748/
