gpt4 book ai didi

python - 如何确定keras中的输入形状?

转载 作者:太空宇宙 更新时间:2023-11-03 15:33:20 25 4
gpt4 key购买 nike

我在构建深度学习模型时很难找到我的错误所在,但我通常在设置输入层输入形状时遇到问题。

这是我的模型:

model = Sequential([
Dense(32, activation='relu', input_shape=(1461, 75)),
Dense(32, activation='relu'),
Dense(ytrain.size),])

它返回以下错误:

 ValueError: Error when checking input: expected dense_1_input to have 3

dimensions, but got array with shape (1461, 75)

数组是kaggle房价竞赛的训练集,我的数据集有75列1461行。我的数组是 2 维的,那么为什么需要 3 维呢?我尝试在第一个密集层之前添加一个冗余的第 3 维 1 或展平数组,但错误只是变成了:

ValueError: Input 0 is incompatible with layer flatten_1: expected 

min_ndim=3, found ndim=2

您如何确定输入大小应该是多少以及为什么它期望的尺寸看起来如此随意?

作为引用,我附上了我的其余代码:

xtrain = pd.read_csv("pricetrain.csv")
test = pd.read_csv("pricetest.csv")
xtrain.fillna(xtrain.mean(), inplace=True)
xtrain.drop(["Alley"], axis=1, inplace=True)
xtrain.drop(["PoolQC"], axis=1, inplace=True)
xtrain.drop(["Fence"], axis=1, inplace=True)
xtrain.drop(["MiscFeature"], axis=1, inplace=True)
xtrain.drop(["PoolArea"], axis=1, inplace=True)
columns = list(xtrain)
for i in columns:
if xtrain[i].dtypes == 'object':
xtrain[i] = pd.Categorical(pd.factorize(xtrain[i])[0])
from sklearn import preprocessing

le = preprocessing.LabelEncoder()
for i in columns:
if xtrain[i].dtypes == 'object':
xtrain[i] = le.fit_transform(xtrain[i])
ytrain = xtrain["SalePrice"]
xtrain.drop(["SalePrice"], axis=1, inplace=True)
ytrain = ytrain.values
xtrain = xtrain.values
ytrain.astype("float32")

size = xtrain.size
print(ytrain)
model = Sequential(
[Flatten(),
Dense(32, activation='relu', input_shape=(109575,)),
Dense(32, activation='relu'),
Dense(ytrain.size),
])
model.compile(loss='mse', optimizer='adam')
model.fit(xtrain, ytrain, epochs=10, verbose=1)

任何建议都会非常有帮助!

谢谢。

最佳答案

第 0 维(样本轴)由训练的 batch_size 决定。您在定义输入形状时将其省略。这是有道理的,否则您的模型将取决于数据集中的样本数

输出也是如此。您似乎只预测每个示例的单个值 ("SalePrice")。所以输出层的形状为 1。

model = Sequential([
Dense(32, activation='relu', input_shape=(75, )),
Dense(32, activation='relu'),
Dense(1),
])

关于python - 如何确定keras中的输入形状?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56565281/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com