gpt4 book ai didi

python - Tensorflow Iris 数据集永远不会收敛

转载 作者:行者123 更新时间:2023-11-30 09:08:33 25 4
gpt4 key购买 nike

我正在尝试在 Iris 数据集上运行标准神经网络。标签是单列,可以具有值 0、1、2,具体取决于物种。我将特征转置到 x 轴,并将示例转置到 y 轴。

值得关注的领域:成本函数 - 每个人似乎都使用预构建的函数,但由于我的数据不是一次性编码的,所以我使用标准损失。优化器 - 我将它用作黑匣子,不确定我是否能够正确更新成本。

预先感谢您的帮助。

import tensorflow as tf
import numpy as np
import pandas as pd

import tensorflow as tf


def create_layer(previous_layer, weight, bias, activation_function=None):
z = tf.add(tf.matmul(weight, previous_layer), bias)
if activation_function is None:
return z
a = activation_function(z)
return a


def cost_compute(prediction, correct_values):
return tf.nn.softmax_cross_entropy_with_logits(logits = prediction, labels = correct_values)

input_features = 4
n_hidden_units1 = 10
n_hidden_units2 = 14
n_hidden_units3 = 12
n_hidden_units4 = 1

rate = .000001

weights = dict(
w1=tf.Variable(tf.random_normal([n_hidden_units1, input_features])),
w2=tf.Variable(tf.random_normal([n_hidden_units2, n_hidden_units1])),
w3=tf.Variable(tf.random_normal([n_hidden_units3, n_hidden_units2])),
w4=tf.Variable(tf.random_normal([n_hidden_units4, n_hidden_units3]))
)

biases = dict(
b1=tf.Variable(tf.zeros([n_hidden_units1, 1])),
b2=tf.Variable(tf.zeros([n_hidden_units2, 1])),
b3=tf.Variable(tf.zeros([n_hidden_units3, 1])),
b4=tf.Variable(tf.zeros([n_hidden_units4, 1]))
)

train = pd.read_csv("/Users/yazen/Desktop/datasets/iris_training.csv")
test = pd.read_csv("/Users/yazen/Desktop/datasets/iris_test.csv")

train.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'species']
test.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'species']

train_labels = np.expand_dims(train['species'].as_matrix(), 1)
test_labels = np.expand_dims(test['species'].as_matrix(), 1)

train_features = train.drop('species', axis=1)
test_features = test.drop('species', axis=1)

test_labels = test_labels.transpose()
train_labels = train_labels.transpose()
test_features = test_features.transpose()
train_features = train_features.transpose()

x = tf.placeholder("float32", [4, None], name="asdfadsf")
y = tf.placeholder("float32", [1, None], name="asdfasdf2")

layer = create_layer(x, weights['w1'], biases['b1'], tf.nn.relu)
layer = create_layer(layer, weights['w2'], biases['b2'], tf.nn.relu)
layer = create_layer(layer, weights['w3'], biases['b3'], tf.nn.relu)
Z4 = create_layer(layer, weights['w4'], biases['b4'])
cost = cost_compute(Z4, y)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for iteration in range(1,50):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=rate).minimize(cost)
_, c = sess.run([optimizer, cost], feed_dict={x: train_features, y: train_labels})
print("Iteration " + str(iteration) + " cost: " + str(c))

prediction = tf.equal(Z4, y)
accuracy = tf.reduce_mean(tf.cast(prediction, "float"))
print(sess.run(Z4, feed_dict={x: train_features, y: train_labels}))
print(accuracy.eval({x: train_features, y: train_labels}))

最佳答案

由于您遇到分类问题,因此需要将标签转换为 one-hot 形式。您可以使用tf.one_hot以此目的。此外,您还可以对成本应用 tf.reduce_mean,如下面的示例所示(取自 here )。另外,你的学习率对我来说似乎太小了。

  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# Train
for _ in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Test trained model
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
y_: mnist.test.labels}))

关于python - Tensorflow Iris 数据集永远不会收敛,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46019893/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com