gpt4 book ai didi

python - RNN 模型上的网格搜索时训练数据的形状不正确

转载 作者:太空宇宙 更新时间:2023-11-03 20:03:34 25 4
gpt4 key购买 nike

请帮助我编写以下代码,当我尝试直接在数据上拟合模型时,该代码运行良好,但在网格搜索上失败(我已注释掉直接 model.fit()我的 grid.fit() 语句末尾的部分,这给了我满意的结果。还请告诉我,我在 reshape 训练数据时是否错误,因为我对这个领域非常陌生。我的预处理数据集mt的形状为(96,variable)(列表的列表)

x_train=[]
y_train=[]

from keras.preprocessing.sequence import pad_sequences
padded = pad_sequences(mt)
for i in range(len(padded)):
x_train.append(padded[i,:-1])
y_train.append(padded[i,1:])

x_train=np.reshape(np.array(x_train),(len(x_train),len(x_train[0]),1))
#converted to 3d tensor of (batch_size,time_steps,feature_dim)

for i in range(len(y_train)):
#print(len(lab))
y_train[i]=to_categorical(y_train[i], num_classes=1503)
y_train=np.array(y_train, dtype='int32')

###################MY MODEL####################
def get_model():
model=Sequential()
model.add(GRU(1, implementation=1, activity_regularizer=regularizers.l1(0.01), return_sequences=True, input_shape=(None, 1)))
model.add(TimeDistributed(Dense(1503, activation='softmax')))

print(model.summary())

model.compile(loss='categorical_crossentropy',
optimizer='Nadam', metrics = ['accuracy'])
return model

np.random.seed(7)
from keras.wrappers.scikit_learn import KerasClassifier
mmodel=KerasClassifier(build_fn=get_model)
from sklearn.model_selection import GridSearchCV
batch_size=[1, 4, 6, 8, 12]
epochs=[10,20,30]
#optimizer=['SGD','RMSprop','Adagrad','Adadelta','Adam','Adamax','Nadam'] #, optimizer=optimizer
param_grid=dict(batch_size=batch_size, epochs=epochs)
grid=GridSearchCV(estimator=mmodel,param_grid=param_grid, n_jobs=1, cv=3)
grid_result=grid.fit(x_train,y_train)
#model.fit(x_train,y_train, epochs=10, verbose=1, validation_split=False, batch_size=1, shuffle=True,
# callbacks=False)

最佳答案

实际上,您不需要 reshape 训练数据x_train。您必须更改 GRU 层的 input_shape

model.add(GRU(1, implementation=1, activity_regularizer=regularizers.l1(0.01), return_sequences=True, input_shape=(len(x_train[0]),)))

但是,如果您想保留 reshape 的训练数据,那么您的输入形状将为 input_shape=(len(x_train[0]),1))

关于python - RNN 模型上的网格搜索时训练数据的形状不正确,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59095447/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com