gpt4 book ai didi

python - Keras:batch_size 的类型错误

转载 作者:行者123 更新时间:2023-11-30 08:33:35 25 4
gpt4 key购买 nike

我正在使用 Keras(使用 Python 3.6)来预测数组 (x_test) 的输出,但我得到了一个 TypeError 作为返回。

这是我的预测代码:

x_test = [[8],[6],[0],[2],[0],[0],[0],[0],[112.128],[0],[0],[2],[0],[1],[1],[2],[2]]
prediction = model.predict(model, x_test, batch_size = 32, verbose = 1)

这是我得到的错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-14-286495dc15a7> in <module>()
1 x_test = [[8],[6],[0],[2],[0],[0],[0],[0],[112.128],[0],[0],[2],[0],[1],[1],[2],[2]]
2
----> 3 prediction = model.predict(model, x_test, batch_size =(17,1), verbose = 1)

TypeError: predict() got multiple values for argument 'batch_size'

如果有人对问题所在有任何建议,我们将不胜感激。

作为引用,这是我的神经网络,它似乎工作正常。

model = Sequential()

model.add(Dense(32, input_dim=17, init='uniform', activation='relu' ))
model.add(Dense(64, init='uniform', activation='relu'))
model.add(Dense(128, init='uniform', activation='relu'))
model.add(Dense(64, init='uniform', activation='sigmoid'))
model.add(Dense(32, init='uniform', activation='sigmoid'))
model.add(Dense(16, init='uniform', activation='sigmoid'))
model.add(Dense(8, init='uniform', activation='sigmoid'))
model.add(Dense(4, init='uniform', activation='sigmoid'))
model.add(Dense(1, init='uniform', activation='sigmoid'))

# Compile model
model.compile(loss='mean_squared_logarithmic_error', optimizer='SGD', metrics=['accuracy'])

# Fit model
history = model.fit(X, Y, nb_epoch=300, validation_split=0.2, batch_size=3)

非常感谢!

最佳答案

您不需要在 model.predict 中传递 model 参数,因为预测的默认值为 predict(self, x, batch_size=32 , verbose=0) 其中 modelself 自动定义。

所以你的代码应该是这样的:

prediction = model.predict(x_test, batch_size = 32, verbose = 1)

根据文档,x 应该是 numpy.array 而不是 list

Arguments:

x: the input data, as a Numpy array.

batch_size: integer.

verbose: verbosity mode, 0 or 1.

这意味着 x_test 应该是:

x_test = np.array([[8],[6],[0],[2],[0],[0],[0],[0],[112.128],[0],[0],[2],[0],[1],[1],[2],[2]])

关于python - Keras:batch_size 的类型错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43058675/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com