gpt4 book ai didi

Python keras - 模型层出现错误

转载 作者:太空宇宙 更新时间:2023-11-03 15:14:57 25 4
gpt4 key购买 nike

我是使用人工神经网络进行分类问题的新手

我有一个分类问题,其中输入数据是 8 列小数值,这些值是度量,输出数据是 8 列整数值,这些值是对象

INPUTS

785.39 6.30 782.75 771.82 7.53 -94.86 378.66 771.82
.
.
.

OUTPUTS

8 9 5 7 3 1 6 2
.
.
.

训练数据的记录为800条,测试数据的记录为200条

这是代码

import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils

seed = 7
numpy.random.seed(seed)
datasetTrain = numpy.loadtxt("train.csv", delimiter=",")
datasetTest = numpy.loadtxt("test.csv", delimiter=",")

X_train = datasetTrain[:,0:7]
y_train = datasetTrain[:,8:15]

X_test = datasetTest[:,0:7]
y_test = datasetTest[:,8:15]

y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)

def baseline_model():
# create model
model = Sequential()
model.add(Dense(7, input_dim=7, kernel_initializer='normal', activation='relu'))
model.add(Dense((5593, 785), kernel_initializer='normal', activation='softmax'))
# Compile model
model.compile(loss='categorical_crossentropy', optimizer='adam')
return model

model = baseline_model()
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=400,
batch_size=200, verbose=25)
scores = model.evaluate(X_test, y_test, verbose=0)
print("Baseline Error: %.2f%%" % (100-scores[1]*100))

我收到此错误

Traceback (most recent call last):
File "proyecto.py", line 29, in <module>
model = baseline_model()
File "proyecto.py", line 24, in baseline_model
model.add(Dense((5593, 785), kernel_initializer='normal', activation='softmax'))
ValueError: setting an array element with a sequence.

哪个模型最适合此数据?

最佳答案

这部分:

model.add(Dense((5593, 785), kernel_initializer='normal', activation='softmax'))

是错误的,Dense的第一个参数是输出神经元的数量,它应该是标量,而不是元组或向量。如果您想要 2D 形状的输出,则可以使用 Reshape 图层来 reshape 输出并执行以下操作:

model.add(Dense(5593 * 785, kernel_initializer='normal', activation='softmax'))
model.add(Reshape((5593, 785)))

关于Python keras - 模型层出现错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43944169/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com