
python-3.x - Why does my Keras model perform so poorly on the Iris dataset?

Reposted. Author: 行者123. Updated: 2023-11-30 09:27:19

I was working through this Keras tutorial and found something interesting.

I trained a logistic regression model with sklearn, and it performs reasonably well:

import seaborn as sns
import numpy as np
# sklearn.cross_validation was removed in scikit-learn 0.20;
# train_test_split now lives in sklearn.model_selection.
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.utils import to_categorical

# Load the iris dataset from seaborn.
iris = sns.load_dataset("iris")

# Use the first 4 variables to predict the species.
# Cast the features to float so Keras does not receive an object array.
X, y = iris.values[:, 0:4].astype(float), iris.values[:, 4]

# Split both independent and dependent variables in half
# for cross-validation
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.5, random_state=0)

# Train a scikit-learn log-regression model
lr = LogisticRegressionCV()
lr.fit(train_X, train_y)

# Test the model. Print the accuracy on the test data
pred_y = lr.predict(test_X)
print("Accuracy is {:.2f}".format(lr.score(test_X, test_y))) # Accuracy is 0.83

83% is pretty good, but with deep learning we should be able to do better. I train a Keras model...

# Define a one-hot encoding of variables in an array.
def one_hot_encode_object_array(arr):
    '''One hot encode a numpy array of objects (e.g. strings)'''
    # ids maps every element of arr to the index of its unique value
    uniques, ids = np.unique(arr, return_inverse=True)
    return to_categorical(ids, len(uniques))

# One-hot encode the train and test y's
train_y_ohe = one_hot_encode_object_array(train_y)
test_y_ohe = one_hot_encode_object_array(test_y)

# Build the keras model

model = Sequential()
# 4 features in the input layer (the four flower measurements)
# 16 hidden units
model.add(Dense(16, input_shape=(4,)))
model.add(Activation('sigmoid'))
# 3 classes in the output layer (corresponding to the 3 species)
model.add(Dense(3))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train the keras model (number of epochs left at the library default)
model.fit(train_X, train_y_ohe, verbose=0, batch_size=1)

# Test the model. Print the accuracy on the test data
loss, accuracy = model.evaluate(test_X, test_y_ohe, verbose=0)
print("Accuracy is {:.2f}".format(accuracy)) # Accuracy is 0.60????

When I train the Keras model, its accuracy is actually worse than that of the logistic regression model.

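That would make sense for some data, but data as close to linearly separable as iris should be very easy for a Keras Sequential model to learn. I tried increasing the size of the hidden layer to 32, 64, and 128 units, but the accuracy did not improve.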
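A rough parameter count helps put the hidden-layer experiment in context. The sketch below is only an illustration: `dense_params` is a hypothetical helper, not a Keras function, and it assumes a fully connected layer with one bias per unit. It compares the 4-16-3 network from the question with a plain softmax (multinomial logistic) layer on 4 features and 3 classes.

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases of one fully connected layer (hypothetical helper)."""
    return n_inputs * n_units + n_units

# A single softmax layer on 4 features and 3 classes.
softmax_only = dense_params(4, 3)

# The network from the question, for each hidden size that was tried.
network_params = {}
for hidden in [16, 32, 64, 128]:
    network_params[hidden] = dense_params(4, hidden) + dense_params(hidden, 3)

print(softmax_only)    # 15
print(network_params)  # {16: 131, 32: 259, 64: 515, 128: 1027}
```

Even the smallest network has several times as many parameters as the single softmax layer, so a lack of capacity seems an unlikely explanation for the low accuracy.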

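The plot below shows the iris data, with the independent variables plotted as a function of species (the dependent variable):

[Figure: Iris data]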

Why does my model perform so poorly?

Best answer

I dropped the one-hot encoding and used Keras's sparse_categorical_crossentropy loss instead, which takes integer class labels directly.
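To see what the loss swap does and does not change, here is a minimal pure-Python sketch of the two losses for a single sample. The function names mirror the Keras loss names but are re-implemented here for illustration, and the probability vector is made up. For a one-hot target, both forms reduce to the negative log of the probability assigned to the true class.

```python
import math

def one_hot(label, num_classes):
    """Return a list with 1.0 at index `label` and 0.0 elsewhere."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

def categorical_crossentropy(one_hot_target, probs):
    """Cross-entropy against a one-hot target vector."""
    return -sum(t * math.log(p) for t, p in zip(one_hot_target, probs))

def sparse_categorical_crossentropy(label, probs):
    """Cross-entropy against an integer class label."""
    return -math.log(probs[label])

# Hypothetical softmax output for one flower over the 3 species.
probs = [0.7, 0.2, 0.1]
label = 1  # integer class id, in the style of sklearn's iris.target

dense_loss = categorical_crossentropy(one_hot(label, 3), probs)
sparse_loss = sparse_categorical_crossentropy(label, probs)
print(dense_loss, sparse_loss)
```

Since the two values agree, switching the loss is mainly a convenience that removes the encoding step; by itself it should not be expected to change what the model learns.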

The obvious thing to try is increasing the number of training epochs (the default at the time was 10; let's try 100).
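The effect of the epoch count can be made concrete by counting optimizer steps. This is a back-of-the-envelope sketch: `num_weight_updates` is a hypothetical helper, and it assumes one weight update per batch, with the 150 iris rows split in half into 75 training samples as in the question.

```python
import math

def num_weight_updates(num_samples, batch_size, epochs):
    """Optimizer steps taken during training: batches per epoch times epochs."""
    batches_per_epoch = math.ceil(num_samples / batch_size)
    return batches_per_epoch * epochs

# 150 iris rows split with train_size=0.5 leave 75 training samples.
train_samples = 75
updates_10 = num_weight_updates(train_samples, batch_size=1, epochs=10)
updates_100 = num_weight_updates(train_samples, batch_size=1, epochs=100)
print(updates_10, updates_100)  # 750 7500
```

With batch_size=1, going from 10 to 100 epochs gives the optimizer ten times as many steps, which is plausibly what the small network needs to converge on unscaled features.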

Code

from sklearn.datasets import load_iris
# sklearn.cross_validation was removed in scikit-learn 0.20;
# train_test_split now lives in sklearn.model_selection.
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV
import numpy as np

from keras.models import Sequential
from keras.layers import Dense, Activation

# Load the iris dataset from scikit-learn.
iris = load_iris()

# Use the first 4 variables to predict the species.
# iris.target already holds integer class ids (0, 1, 2).
X, y = iris.data[:, :4], iris.target

# Split both independent and dependent variables in half
# for cross-validation
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.5, random_state=0)

# Train a scikit-learn log-regression model
lr = LogisticRegressionCV()
lr.fit(train_X, train_y)

# Test the model. Print the accuracy on the test data
pred_y = lr.predict(test_X)
print("Accuracy is {:.2f}".format(lr.score(test_X, test_y))) # Accuracy is 0.83


# Build the keras model

model = Sequential()
# 4 features in the input layer (the four flower measurements)
# 16 hidden units
model.add(Dense(16, input_shape=(4,)))
model.add(Activation('sigmoid'))
# 3 classes in the output layer (corresponding to the 3 species)
model.add(Dense(3))
model.add(Activation('softmax'))
# The sparse loss accepts the integer labels without one-hot encoding.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train the keras model for 100 epochs
# (this argument was named nb_epoch in Keras 1.x)
model.fit(train_X, train_y, verbose=1, batch_size=1, epochs=100)

# Test the model. Print the accuracy on the test data
loss, accuracy = model.evaluate(test_X, test_y, verbose=0)
print("Accuracy is {:.2f}".format(accuracy))

Output

Accuracy is 0.83
Accuracy is 0.99

Source: this question and its answer come from Stack Overflow: https://stackoverflow.com/questions/39810655/
