gpt4 book ai didi

python - 为什么使用 TensorFlow 计算简单线性回归时得到 [nan]?

转载 作者:行者123 更新时间:2023-11-30 09:34:06 25 4
gpt4 key购买 nike

当我使用 TensorFlow 计算简单的线性回归时,我得到 [nan],包括:w、b 和损失。

这是我的代码:

import tensorflow as tf

w = tf.Variable(tf.zeros([1]), tf.float32)
b = tf.Variable(tf.zeros([1]), tf.float32)
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

liner = w*x+b

loss = tf.reduce_sum(tf.square(liner-y))

train = tf.train.GradientDescentOptimizer(1).minimize(loss)

sess = tf.Session()

x_data = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000]
y_data = [265000, 324000, 340000, 412000, 436000, 490000, 574000, 585000, 680000]

sess.run(tf.global_variables_initializer())

for i in range(1000):
sess.run(train, {x: x_data, y: y_data})

nw, nb, nloss = sess.run([w, b, loss], {x: x_data, y: y_data})

print(nw, nb, nloss)

输出:

[ nan] [ nan] nan

Process finished with exit code 0

为什么会发生这种情况,如何解决?

最佳答案

使用如此高的学习率(在你的例子中为1)你已经溢出了。尝试使用 0.001 的学习率。此外,您的数据需要除以 1000 并增加迭代次数,它应该可以工作。这是我测试过的代码,运行良好。

x_data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
y_data = [265, 324, 340, 412, 436, 490, 574, 585, 680]

plt.plot(x_data, y_data, 'ro', label='Original data')
plt.legend()
plt.show()

W = tf.Variable(tf.random_uniform([1], 0, 1))
b = tf.Variable(tf.zeros([1]))
y = W * x_data + b

loss = tf.reduce_mean(tf.square(y - y_data))

optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)
init = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init)

for step in range(0,50000):
sess.run(train)
print(step, sess.run(loss))
print (step, sess.run(W), sess.run(b))

plt.plot(x_data, y_data, 'ro')
plt.plot(x_data, sess.run(W) * x_data + sess.run(b))
plt.legend()
plt.show()

关于python - 为什么使用 TensorFlow 计算简单线性回归时得到 [nan]?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47928554/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com