
java - org.deeplearning4j.exception.DL4JInvalidInputException: Problem with creating array with data to predict

Reposted. Author: 行者123. Updated: 2023-12-01 21:14:13

When I run the following code:

model.output(samples).getDouble(0);

I get this error:

org.deeplearning4j.exception.DL4JInvalidInputException: 
Cannot do forward pass in Convolution layer (layer name = conv1d_1, layer index = 0): input array channels does not match CNN layer configuration
(data input channels = 80, [minibatch,inputDepth,height,width]=[1, 80, 3, 1]; expected input channels = 3)
(layer name: conv1d_1, layer index: 0, layer type: Convolution1DLayer)

I build the data as a float[] array of length 240 and create the INDArray like this:

 INDArray features = Nd4j.create(data, new int[]{1, 240}, 'c');

This is my Keras model:

model = Sequential()
model.add(Reshape((const.PERIOD, const.N_FEATURES), input_shape=(240,)))
model.add(Conv1D(100, 10, activation='relu', input_shape=(const.PERIOD, const.N_FEATURES)))
model.add(Conv1D(100, 10, activation='relu'))
model.add(MaxPooling1D(const.N_FEATURES))
model.add(Conv1D(160, 10, activation='relu'))
model.add(Conv1D(160, 10, activation='relu'))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(7, activation='softmax'))
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

where PERIOD = 80 and N_FEATURES = 3.
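The first error reports the input as [minibatch, inputDepth, height, width] = [1, 80, 3, 1]: the 80 time steps ended up in the channel position, while conv1d_1 expects 3 channels. Keras's Reshape((PERIOD, N_FEATURES)) reads the flat 240-value vector in row-major order into a channels-last (80, 3) array, where each of the 80 time steps holds 3 features, whereas the shape in the error treats dimension 1 as the channels. The plain-Java sketch below (no DL4J dependency; the class and method names are hypothetical) only illustrates the index mapping from that channels-last layout to a channels-first [3][80] layout. It is not a confirmed fix: whether the imported model accepts such a layout is an assumption, since its ReshapePreprocessor still runs before the first layer.

```java
/**
 * Hypothetical helper (not DL4J API): rearranges one flat Keras-style sample,
 * stored as time steps x features in row-major (channels-last) order,
 * into channels-first order [features][timeSteps].
 */
public class ChannelsLastToFirst {

    // data[t * nFeatures + f] holds feature f at time step t, which is how
    // Keras Reshape((period, nFeatures)) lays out a flat (period * nFeatures,) vector.
    public static float[][] toChannelsFirst(float[] data, int period, int nFeatures) {
        if (data.length != period * nFeatures) {
            throw new IllegalArgumentException(
                "expected " + (period * nFeatures) + " values, got " + data.length);
        }
        float[][] out = new float[nFeatures][period];
        for (int t = 0; t < period; t++) {
            for (int f = 0; f < nFeatures; f++) {
                out[f][t] = data[t * nFeatures + f]; // transpose (t, f) -> (f, t)
            }
        }
        return out;
    }

    public static void main(String[] args) {
        final int PERIOD = 80, N_FEATURES = 3;
        float[] data = new float[PERIOD * N_FEATURES];
        for (int i = 0; i < data.length; i++) data[i] = i; // value = flat index
        float[][] cf = toChannelsFirst(data, PERIOD, N_FEATURES);
        System.out.println(cf.length + " x " + cf[0].length + ", cf[1][2] = " + cf[1][2]);
    }
}
```

Running main prints `3 x 80, cf[1][2] = 7.0`: channel 1 at time step 2 comes from data[2 * 3 + 1].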

If I set the shape to:

 INDArray features = Nd4j.create(data, new int[]{240, 1});

then the error is:

IllegalStateException: Input shape [240, 1] and output shape[240, 1] do not match
at org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor.preProcess(ReshapePreprocessor.java:103)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.outputOfLayerDetached(MultiLayerNetwork.java:1256)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2340)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2303)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2294)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2281)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2377)
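In both Keras (input_shape=(240,)) and DL4J's MultiLayerNetwork.output, the first dimension of the input is the minibatch. A {240, 1} array therefore describes 240 examples of one value each rather than one example of 240 values, and a single value cannot be reshaped to (80, 3); this is likely what the ReshapePreprocessor rejects, even though its message is hard to read. The {1, 240} shape from the first attempt has the right per-example size. A small plain-Java sketch of that count follows; ShapeCheck and valuesPerExample are hypothetical names, not DL4J API.

```java
import java.util.Arrays;

/** Hypothetical check (not DL4J API): values carried by each example for a given shape. */
public class ShapeCheck {

    // Dimension 0 is the minibatch; the remaining dimensions describe one example.
    public static long valuesPerExample(long[] shape) {
        long n = 1;
        for (int i = 1; i < shape.length; i++) n *= shape[i];
        return n;
    }

    public static void main(String[] args) {
        long expected = 80 * 3; // PERIOD * N_FEATURES, the target of Reshape((80, 3))
        long[][] shapes = { {1, 240}, {240, 1} };
        for (long[] s : shapes) {
            long per = valuesPerExample(s);
            System.out.println(Arrays.toString(s) + ": " + s[0] + " example(s) x " + per
                + " value(s) -> " + (per == expected ? "can" : "cannot") + " reshape to (80, 3)");
        }
    }
}
```

For {1, 240} the check reports 240 values per example, matching 80 x 3; for {240, 1} it reports 1.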

Best answer

Could you open an issue? This looks like a bug. Thanks. https://github.com/eclipse/deeplearning4j/issues

Regarding java - org.deeplearning4j.exception.DL4JInvalidInputException: Problem with creating array with data to predict, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/58885964/
