gpt4 book ai didi

machine-learning - Keras 回归每次都会对我的输入给出不同的预测

转载 作者:行者123 更新时间:2023-11-30 08:58:55 24 4
gpt4 key购买 nike

我使用以下代码构建了一个 Keras 回归器:

from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasRegressor
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline


import numpy as ny
import pandas

from numpy.random import seed
seed(1)
from tensorflow import set_random_seed
set_random_seed(2)

X = ny.array([[1,2], [3,4], [5,6], [7,8], [9,10]])
sc_X=StandardScaler()
X_train = sc_X.fit_transform(X)

Y = ny.array([3, 4, 5, 6, 7])
Y=ny.reshape(Y,(-1,1))
sc_Y=StandardScaler()
Y_train = sc_Y.fit_transform(Y)

N = 5

def brain():
#Create the brain
br_model=Sequential()
br_model.add(Dense(3, input_dim=2, kernel_initializer='normal',activation='relu'))
br_model.add(Dense(2, kernel_initializer='normal',activation='relu'))
br_model.add(Dense(1,kernel_initializer='normal'))

#Compile the brain
br_model.compile(loss='mean_squared_error',optimizer='adam')
return br_model


def predict(X,sc_X,sc_Y,estimator):
prediction = estimator.predict(sc_X.fit_transform(X))
return sc_Y.inverse_transform(prediction)

estimator = KerasRegressor(build_fn=brain, epochs=1000, batch_size=5,verbose=0)
# print "Done"


estimator.fit(X_train,Y_train)
prediction = estimator.predict(X_train)


print predict(X,sc_X,sc_Y,estimator)

X_test = ny.array([[1.5,4.5], [7,8], [9,10]])
print predict(X_test,sc_X,sc_Y,estimator)

我面临的问题是代码没有预测相同的值(例如,它在第一个预测 (X) 中为 [9,10] 预测 6.64,在第二个预测 (X) 中为 [9,10] 预测 6.49 ( X_测试))完整的输出是这样的:

[2.9929883 4.0016675 5.0103474 6.0190268 6.6434317]
[3.096634 5.422326 6.4955378]

为什么我会得到不同的值以及如何解决这些问题?

最佳答案

问题出在这行代码:

prediction = estimator.predict(sc_X.fit_transform(X))

每次预测新数据的值时,您都会拟合一个新的缩放器。这就是差异的来源。尝试:

prediction = estimator.predict(sc_X.transform(X))

在本例中,您使用预先训练的缩放器。

关于machine-learning - Keras 回归每次都会对我的输入给出不同的预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48507991/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com