gpt4 book ai didi

python - Scikit-learn 中的受限玻尔兹曼机 : Iris Classification

转载 作者:行者123 更新时间:2023-11-30 08:55:08 26 4
gpt4 key购买 nike

我正在研究在 Iris 数据集上应用受限玻尔兹曼机的示例。本质上,我试图对人民币和LDA进行比较。 LDA 似乎产生了合理的正确输出结果,但 RBM 却不然。根据建议,我使用 skearn.preprocessing.Binarizer 对特征输入进行二值化,并尝试了不同的阈值参数值。我尝试了几种不同的方法来应用二值化,但似乎没有一个适合我。

下面是我根据该用户的版本 User: covariance 修改后的代码版本.

非常感谢任何有用的评论。

from sklearn import linear_model, datasets, preprocessing
from sklearn.cross_validation import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.neural_network import BernoulliRBM
from sklearn.lda import LDA

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:,:2] # we only take the first two features.
Y = iris.target

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=10)

# Models we will use
rbm = BernoulliRBM(random_state=0, verbose=True)
binarizer = preprocessing.Binarizer(threshold=0.01,copy=True)
X_binarized = binarizer.fit_transform(X_train)
hidden_layer = rbm.fit_transform(X_binarized, Y_train)
logistic = linear_model.LogisticRegression()
logistic.coef_ = hidden_layer
classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])
lda = LDA(n_components=3)


#########################################################################

# Training RBM-Logistic Pipeline
logistic.fit(X_train, Y_train)
classifier.fit(X_binarized, Y_train)

#########################################################################

# Get predictions
print "The RBM model:"
print "Predict: ", classifier.predict(X_test)
print "Real: ", Y_test

print

print "Linear Discriminant Analysis: "
lda.fit(X_train, Y_train)
print "Predict: ", lda.predict(X_test)
print "Real: ", Y_test

最佳答案

RBM 和 LDA 不能直接比较,因为 RBM 本身不执行分类。尽管您将其用作最后的逻辑回归的特征工程步骤,但 LDA 本身就是一个分类器 - 因此比较意义不大。

scikit learn 中的 BernoulliRBM 仅处理二进制输入。 iris 数据集没有合理的二值化,因此您不会获得任何有意义的输出。

关于python - Scikit-learn 中的受限玻尔兹曼机 : Iris Classification,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32744033/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com