gpt4 book ai didi

python-3.x - 在 Keras (Python) 中使用神经网络进行样本外预测

转载 作者:行者123 更新时间:2023-12-04 15:18:30 28 4
gpt4 key购买 nike

我正在使用窗口方法进行时间序列预测练习,但我很难理解如何进行样本外预测。
这是代码:

def windowed_dataset(series, window_size, batch_size, shuffle_buffer):
dataset = tf.data.Dataset.from_tensor_slices(series)
dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1], window[-1]))
dataset = dataset.batch(batch_size).prefetch(1)
return dataset

dataset = windowed_dataset(x_train, window_size, batch_size, shuffle_buffer_size)
函数 windowed_dataset拆分单变量时间序列 series成矩阵。想象一下,我们有一个数据集如下
dataset = tf.data.Dataset.range(10)
for val in dataset:
print(val.numpy())
0
1
2
3
4
5
6
7
8
9
windowed_dataset函数转换 series使用 x features 进入窗口在左侧和 y labels在右边。
[2 3 4 5] [6]
[4 5 6 7] [8]
[3 4 5 6] [7]
[1 2 3 4] [5]
[5 6 7 8] [9]
[0 1 2 3] [4]
下一步,我们在训练 dataset 上实现神经网络模型如下:
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(10, input_shape=[window_size], activation="relu"),
tf.keras.layers.Dense(10, activation="relu"),
tf.keras.layers.Dense(1)
])

model.compile(loss="mse", optimizer=tf.keras.optimizers.SGD(lr=1e-6, momentum=0.9))
model.fit(dataset,epochs=100,verbose=0)
到这里为止,我对代码很好。但是,我很难理解如下所示的样本外预测:
forecast = []
for time in range(len(series) - window_size):
forecast.append(model.predict(series[time:time + window_size][np.newaxis]))
forecast = forecast[split_time-window_size:]
有人可以向我解释为什么我们在这里使用循环 time in range(len(series) - window_size) ?为什么不简单地做 model.predict(dataset_validation)用于验证部分和 model.predict(dataset)培训部分?
我不明白 for loop 的必要性因为这不是滚动预测,所以我们不会重新训练模型。有人可以向我解释一下吗?
虽然我理解为什么数据科学社区构建了 dataset这样,当我们拆分 X 时,我个人觉得它更清晰。和 y并做 model.fit如下 model.fit(X,y,epochs=100,verbose=0)predict如下 model.predict(X)

最佳答案

for 循环按顺序返回预测,而如果您调用 model.predict(dataset_validation) ,您将获得混洗顺序的预测(假设您混洗了数据集)。
至于使用数据集的点 - 它可以帮助代码组织。如果您不想,则无需使用。

关于python-3.x - 在 Keras (Python) 中使用神经网络进行样本外预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63859402/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com