gpt4 book ai didi

python - 使用 TFlearn 塑造线性回归数据

转载 作者:太空宇宙 更新时间:2023-11-03 15:04:46 25 4
gpt4 key购买 nike

我正在尝试扩展the tflearn example for linear regression将列数增加到 21。

from trafficdata import X,Y

import tflearn

print(X.shape) #(1054, 21)
print(Y.shape) #(1054,)

# Linear Regression graph
input_ = tflearn.input_data(shape=[None,21])
linear = tflearn.single_unit(input_)
regression = tflearn.regression(linear, optimizer='sgd', loss='mean_square',
metric='R2', learning_rate=0.01)
m = tflearn.DNN(regression)
m.fit(X, Y, n_epoch=1000, show_metric=True, snapshot_epoch=False)

print("\nRegression result:")
print("Y = " + str(m.get_weights(linear.W)) +
"*X + " + str(m.get_weights(linear.b)))

但是,tflearn 提示:

Traceback (most recent call last):
File "linearregression.py", line 16, in <module>
m.fit(X, Y, n_epoch=1000, show_metric=True, snapshot_epoch=False)
File "/usr/local/lib/python3.5/dist-packages/tflearn/models/dnn.py", line 216, in fit
callbacks=callbacks)
File "/usr/local/lib/python3.5/dist-packages/tflearn/helpers/trainer.py", line 339, in fit
show_metric)
File "/usr/local/lib/python3.5/dist-packages/tflearn/helpers/trainer.py", line 818, in _train
feed_batch)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 789, in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 975, in _run
% (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (64,) for Tensor 'TargetsData/Y:0', which has shape '(21,)'

我发现形状 (64, ) 来自 tflearn.regression() 的默认批量大小。

我需要转换标签(Y)吗?通过什么方式?

谢谢!

最佳答案

我也尝试做同样的事情。我做了这些更改以使其正常工作

# linear = tflearn.single_unit(input_)
linear = tflearn.fully_connected(input_, 1, activation='linear')

我的猜测是,对于功能>1,您不能使用tflearn.single_unit()。您可以添加额外的完全连接层,但最后一层必须只有 1 个神经元,因为 Y.shape=(?,1)

关于python - 使用 TFlearn 塑造线性回归数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44763935/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com