gpt4 book ai didi

python - skflow 回归预测多个值

转载 作者:太空狗 更新时间:2023-10-29 21:45:49 24 4
gpt4 key购买 nike

我正在尝试预测一个时间序列:给定 50 个先前的值,我想预测接下来的 5 个值。

为此,我使用了 skflow 包(基于 TensorFlow),这个问题相对接近 Boston example provided in the Github repo .

我的代码如下:

%matplotlib inline
import pandas as pd

import skflow
from sklearn import cross_validation, metrics
from sklearn import preprocessing

filepath = 'CSV/FILE.csv'
ts = pd.Series.from_csv(filepath)

nprev = 50
deltasuiv = 5

def load_data(data, n_prev = nprev, delta_suiv=deltasuiv):

docX, docY = [], []
for i in range(len(data)-n_prev-delta_suiv):
docX.append(np.array(data[i:i+n_prev]))
docY.append(np.array(data[i+n_prev:i+n_prev+delta_suiv]))
alsX = np.array(docX)
alsY = np.array(docY)

return alsX, alsY

X, y = load_data(ts.values)
# Scale data to 0 mean and unit std dev.
scaler = preprocessing.StandardScaler()
X = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y,
test_size=0.2, random_state=42)
regressor = skflow.TensorFlowDNNRegressor(hidden_units=[30, 50],
steps=5000, learning_rate=0.1, batch_size=1)
regressor.fit(X_train, y_train)
score = metrics.mean_squared_error(regressor.predict(X_test), y_test)
print('MSE: {0:f}'.format(score))

这导致:

ValueError: y_true and y_pred have different number of output (1!=5)

在训练结束时。

当我尝试预测时,我遇到了同样的问题

ypred = regressor.predict(X_test)
print ypred.shape, y_test.shape

(200, 1) (200, 5)

因此,我们可以看到该模型以某种方式仅预测 1 个值,而不是 5 个想要/希望的值。

我如何使用同一个模型来预测多个值的值?

最佳答案

我刚刚在 skflow 中添加了对多输出回归的支持,因为这 #e443c734 ,所以请重新安装软件包再试一次。如果还不行,请在Github上跟进。

我还在 examples folder 中添加了一个多输出回归示例:

# Create random dataset.
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T

# Fit regression DNN model.
regressor = skflow.TensorFlowDNNRegressor(hidden_units=[5, 5])
regressor.fit(X, y)
score = mean_squared_error(regressor.predict(X), y)
print("Mean Squared Error: {0:f}".format(score))

关于python - skflow 回归预测多个值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34224826/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com