
machine-learning - Simple linear regression with TensorFlow


I am a beginner with TensorFlow and machine learning, and I want to try a simple linear regression example with TensorFlow.

But after about 3700 epochs the loss no longer decreases. I don't know what is wrong.

Apparently we end up with W = 3.52 and b = 2.8865, so y = 3.52*x + 2.8865. Testing with x = 11 gives y = 41.6065, which must be wrong, because the training data already has x = 10, y = 48.712.

The code and the loss output are posted below.

#Goal: predict the house price in 2017 by linear regression method
#Step: 1. load the original data
# 2. define the placeholder and variable
# 3. linear regression method
# 4. launch the graph

from __future__ import print_function

import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# 1. load the original data
price = np.asarray([6.757, 12.358, 10.091, 11.618, 14.064,
                    16.926, 17.673, 22.271, 26.905, 34.742, 48.712])
year = np.asarray([0,1,2,3,4,5,6,7,8,9,10])
n_samples = price.shape[0]


# 2. define the placeholder and variable
x = tf.placeholder("float")
y_ = tf.placeholder("float")


W = tf.Variable(np.random.randn())
b = tf.Variable(np.random.randn())


# 3. linear regression method
y = tf.add(tf.multiply(x, W), b)

loss = tf.reduce_mean(tf.square(y - y_))/(2*n_samples)
training_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)


# 4. launch the graph
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(10000):
        for (year_epoch, price_epoch) in zip(year, price):
            sess.run(training_step, feed_dict={x: year_epoch, y_: price_epoch})

        if (epoch+1) % 50 == 0:
            loss_np = sess.run(loss, feed_dict={x: year, y_: price})
            print("Epoch: ", '%04d' % (epoch+1), "loss = ", "{:.9f}".format(loss_np), "W = ", sess.run(W), "b = ", sess.run(b))

    # print "Training finish"
    training_loss = sess.run(loss, feed_dict={x: year, y_: price})
    print("Training cost = ", training_loss, "W = ", sess.run(W), "b = ", sess.run(b), '\n')

The loss output is:

Epoch:  0050 loss =  1.231071353 W =  3.88227 b =  0.289058
Epoch: 0100 loss = 1.207471132 W = 3.83516 b = 0.630129
Epoch: 0150 loss = 1.189429402 W = 3.79423 b = 0.926415
Epoch: 0200 loss = 1.175611973 W = 3.75868 b = 1.1838
Epoch: 0250 loss = 1.165009260 W = 3.72779 b = 1.40738
Epoch: 0300 loss = 1.156855702 W = 3.70096 b = 1.60161
Epoch: 0350 loss = 1.150570631 W = 3.67766 b = 1.77033
Epoch: 0400 loss = 1.145712137 W = 3.65741 b = 1.9169
Epoch: 0450 loss = 1.141945601 W = 3.63982 b = 2.04422
Epoch: 0500 loss = 1.139016271 W = 3.62455 b = 2.15483
Epoch: 0550 loss = 1.136731029 W = 3.61127 b = 2.25091
Epoch: 0600 loss = 1.134940267 W = 3.59974 b = 2.33437
Epoch: 0650 loss = 1.133531928 W = 3.58973 b = 2.40688
Epoch: 0700 loss = 1.132419944 W = 3.58103 b = 2.46986
Epoch: 0750 loss = 1.131537557 W = 3.57347 b = 2.52458
Epoch: 0800 loss = 1.130834818 W = 3.5669 b = 2.57211
Epoch: 0850 loss = 1.130271792 W = 3.5612 b = 2.6134
Epoch: 0900 loss = 1.129818439 W = 3.55625 b = 2.64927
Epoch: 0950 loss = 1.129452229 W = 3.55194 b = 2.68042
Epoch: 1000 loss = 1.129154325 W = 3.5482 b = 2.70749
Epoch: 1050 loss = 1.128911495 W = 3.54496 b = 2.731
Epoch: 1100 loss = 1.128711581 W = 3.54213 b = 2.75143
Epoch: 1150 loss = 1.128546953 W = 3.53968 b = 2.76917
Epoch: 1200 loss = 1.128411174 W = 3.53755 b = 2.78458
Epoch: 1250 loss = 1.128297567 W = 3.53571 b = 2.79797
Epoch: 1300 loss = 1.128202677 W = 3.5341 b = 2.8096
Epoch: 1350 loss = 1.128123403 W = 3.5327 b = 2.81971
Epoch: 1400 loss = 1.128056765 W = 3.53149 b = 2.82849
Epoch: 1450 loss = 1.128000259 W = 3.53044 b = 2.83611
Epoch: 1500 loss = 1.127952814 W = 3.52952 b = 2.84274
Epoch: 1550 loss = 1.127912283 W = 3.52873 b = 2.84849
Epoch: 1600 loss = 1.127877355 W = 3.52804 b = 2.85349
Epoch: 1650 loss = 1.127847791 W = 3.52744 b = 2.85783
Epoch: 1700 loss = 1.127822518 W = 3.52692 b = 2.8616
Epoch: 1750 loss = 1.127801418 W = 3.52646 b = 2.86488
Epoch: 1800 loss = 1.127782702 W = 3.52607 b = 2.86773
Epoch: 1850 loss = 1.127766728 W = 3.52573 b = 2.8702
Epoch: 1900 loss = 1.127753139 W = 3.52543 b = 2.87234
Epoch: 1950 loss = 1.127740979 W = 3.52517 b = 2.87421
Epoch: 2000 loss = 1.127731323 W = 3.52495 b = 2.87584
Epoch: 2050 loss = 1.127722263 W = 3.52475 b = 2.87725
Epoch: 2100 loss = 1.127714872 W = 3.52459 b = 2.87847
Epoch: 2150 loss = 1.127707958 W = 3.52444 b = 2.87953
Epoch: 2200 loss = 1.127702117 W = 3.52431 b = 2.88045
Epoch: 2250 loss = 1.127697825 W = 3.5242 b = 2.88126
Epoch: 2300 loss = 1.127693415 W = 3.52411 b = 2.88195
Epoch: 2350 loss = 1.127689362 W = 3.52402 b = 2.88255
Epoch: 2400 loss = 1.127686620 W = 3.52395 b = 2.88307
Epoch: 2450 loss = 1.127683759 W = 3.52389 b = 2.88352
Epoch: 2500 loss = 1.127680898 W = 3.52383 b = 2.88391
Epoch: 2550 loss = 1.127679348 W = 3.52379 b = 2.88425
Epoch: 2600 loss = 1.127677798 W = 3.52374 b = 2.88456
Epoch: 2650 loss = 1.127675653 W = 3.52371 b = 2.88483
Epoch: 2700 loss = 1.127674222 W = 3.52368 b = 2.88507
Epoch: 2750 loss = 1.127673268 W = 3.52365 b = 2.88526
Epoch: 2800 loss = 1.127672315 W = 3.52362 b = 2.88543
Epoch: 2850 loss = 1.127671123 W = 3.5236 b = 2.88559
Epoch: 2900 loss = 1.127670288 W = 3.52358 b = 2.88572
Epoch: 2950 loss = 1.127670050 W = 3.52357 b = 2.88583
Epoch: 3000 loss = 1.127669215 W = 3.52356 b = 2.88592
Epoch: 3050 loss = 1.127668500 W = 3.52355 b = 2.88599
Epoch: 3100 loss = 1.127668381 W = 3.52354 b = 2.88606
Epoch: 3150 loss = 1.127667665 W = 3.52353 b = 2.88615
Epoch: 3200 loss = 1.127667546 W = 3.52352 b = 2.88621
Epoch: 3250 loss = 1.127667069 W = 3.52351 b = 2.88626
Epoch: 3300 loss = 1.127666950 W = 3.5235 b = 2.8863
Epoch: 3350 loss = 1.127666354 W = 3.5235 b = 2.88633
Epoch: 3400 loss = 1.127666593 W = 3.5235 b = 2.88637
Epoch: 3450 loss = 1.127666593 W = 3.52349 b = 2.8864
Epoch: 3500 loss = 1.127666235 W = 3.52349 b = 2.88644
Epoch: 3550 loss = 1.127665997 W = 3.52348 b = 2.88646
Epoch: 3600 loss = 1.127665639 W = 3.52348 b = 2.88648
Epoch: 3650 loss = 1.127665639 W = 3.52348 b = 2.88649
Epoch: 3700 loss = 1.127665997 W = 3.52348 b = 2.8865
Epoch: 3750 loss = 1.127665997 W = 3.52348 b = 2.8865
... (loss, W and b stay at exactly these values for every remaining epoch)
Epoch: 10000 loss = 1.127665997 W = 3.52348 b = 2.8865
Training cost = 1.12767 W = 3.52348 b = 2.8865

Best Answer

Your assumption that the predicted output lies on a straight line is incorrect. Check what the plot of year against price actually looks like. (plot of price versus year)

So the linear hypothesis you chose will try to fit a straight line that satisfies as many of the input points as possible in order to reduce the cost. When you then test a point outside the training range, it predicts the value on that line, i.e. the value that was optimized for the input set you supplied.

Now, about the two problems you mentioned.

1. The cost does not decrease: try lowering the learning rate; your cost will definitely come down further.
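For instance, a minimal sketch of that change applied to the optimizer line from the question (0.001 is just an illustrative value, not a tuned one):

# same optimizer as in the question, only with a smaller (illustrative) learning rate
training_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)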

2. Your output for year = 11 is wrong: the reason was already given above. What you need to do is change the hypothesis: include a squared term and check again, for example y = ax^2 + bx + c. That hypothesis equation will fit this data much better.
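As a rough sketch of that idea (this is not code from the answer: the variable names a, b, c, the Adam optimizer and its learning rate are my own choices for illustration), the quadratic hypothesis can be written with the same TF 1.x API the question uses:

import tensorflow as tf
import numpy as np

# same data as in the question
price = np.asarray([6.757, 12.358, 10.091, 11.618, 14.064,
                    16.926, 17.673, 22.271, 26.905, 34.742, 48.712], dtype=np.float32)
year = np.asarray(range(11), dtype=np.float32)

x = tf.placeholder(tf.float32)
y_ = tf.placeholder(tf.float32)

# quadratic hypothesis y = a*x^2 + b*x + c instead of a straight line
a = tf.Variable(np.random.randn(), dtype=tf.float32)
b = tf.Variable(np.random.randn(), dtype=tf.float32)
c = tf.Variable(np.random.randn(), dtype=tf.float32)
y = a * tf.square(x) + b * x + c

loss = tf.reduce_mean(tf.square(y - y_))
# Adam converges faster here than plain gradient descent; the learning rate is untuned
train_step = tf.train.AdamOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(5000):
        sess.run(train_step, feed_dict={x: year, y_: price})
    print("a =", sess.run(a), "b =", sess.run(b), "c =", sess.run(c))
    print("prediction for year 11:", sess.run(y, feed_dict={x: 11.0}))

Whether a quadratic is really the right model for these house prices is a separate question; the point is only that changing the hypothesis, as the answer suggests, is what makes the extrapolation to year = 11 behave more sensibly.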

Regarding machine-learning - Simple linear regression with TensorFlow, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/46622270/
