gpt4 book ai didi

python - tensorflow GradientDescentOptimizer 不更新变量?

转载 作者:行者123 更新时间:2023-11-30 09:43:57 26 4
gpt4 key购买 nike

我是机器学习新手。我从最简单的使用 softmax 和梯度下降的 mnist 手写图像分类示例开始。通过引用其他一些例子,我提出了自己的逻辑回归:

import tensorflow as tf
import numpy as np


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = np.float32(x_train / 255.0)
x_test = np.float32(x_test / 255.0)

X = tf.placeholder(tf.float32, [None, 28, 28])
Y = tf.placeholder(tf.uint8, [100])

XX = tf.reshape(X, [-1, 784])

W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

def err(x, y):
predictions = tf.matmul(x, W) + b
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=tf.reshape(y, [-1, 1]), logits=predictions))
# value = tf.reduce_mean(y * tf.log(predictions))
# loss = -tf.reduce_mean(tf.one_hot(y, 10) * tf.log(predictions)) * 100.
return loss

# cost = err(np.reshape(x_train[:100], (-1, 784)), y_train[:100])
cost = err(tf.reshape(X, (-1, 784)), Y)

optimizer = tf.train.GradientDescentOptimizer(0.005).minimize(cost)


init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)



# temp = sess.run(tf.matmul(XX, W) + b, feed_dict={X: x_train[:100]})

temp = sess.run(cost, feed_dict={X: x_train[:100], Y: y_train[:100]})
print(temp)
# print(temp.dtype)
# print(type(temp))

for i in range(100):
sess.run(optimizer, feed_dict={X: x_train[i * 100: 100 * (i + 1)], Y: y_train[i * 100: 100 * (i + 1)]})
# sess.run(optimizer, feed_dict={X: x_train[: 100], Y: y_train[:100]})

temp = sess.run(cost, feed_dict={X: x_train[:100], Y: y_train[:100]})
print(temp)


sess.close()

我尝试运行优化器一些迭代,用训练图像数据和标签提供数据。根据我的理解,在优化器运行期间,应该更新“W”和“b”变量,以便模型在训练前后产生不同的结果。但使用这段代码,优化器运行之前和之后模型的打印成本是相同的。可能有什么问题会导致这种情况发生?

最佳答案

您正在用零初始化权重矩阵W,因此,所有参数在每次权重更新时都会收到相同的梯度值。对于权重初始化,请使用 tf.truncated_normal()、tf.random_normal()、tf.contrib.layers.xavier_initializer() 或其他方法,但不是零。

This是一个类似的问题。

关于python - tensorflow GradientDescentOptimizer 不更新变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55006469/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com