gpt4 book ai didi

python - 预测新结果时检查模型输入 keras 时出错

转载 作者:行者123 更新时间:2023-11-30 09:46:28 25 4
gpt4 key购买 nike

我正在尝试使用基于新数据构建的 keras 模型,但在尝试预测时出现输入错误。

这是我的模型代码:

def build_model(max_features, maxlen):
"""Build LSTM model"""
model = Sequential()
model.add(Embedding(max_features, 128, input_length=maxlen))
model.add(LSTM(128))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
optimizer='rmsprop')

return model

以及我的代码来预测新数据的输出预测:

LSTM_model = load_model('LSTMmodel.h5')
data = pickle.load(open('traindata.pkl', 'rb'))


#### LSTM ####

"""Run train/test on logistic regression model"""

# Extract data and labels
X = [x[1] for x in data]
labels = [x[0] for x in data]

# Generate a dictionary of valid characters
valid_chars = {x:idx+1 for idx, x in enumerate(set(''.join(X)))}

max_features = len(valid_chars) + 1
maxlen = np.max([len(x) for x in X])

# Convert characters to int and pad
X = [[valid_chars[y] for y in x] for x in X]
X = sequence.pad_sequences(X, maxlen=maxlen)

# Convert labels to 0-1
y = [0 if x == 'benign' else 1 for x in labels]


y_pred = LSTM_model.predict(X)

运行此代码时出现的错误:

ValueError: Error when checking input: expected embedding_1_input to have shape (57,) but got array with shape (36,)

我的错误来自 maxlen,因为对于我的训练数据 maxlen=57 和我的新数据 maxlen=36

因此我尝试在预测代码中设置 maxlen=57 但随后出现此错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[31,53] = 38 is not in [0, 38)
[[Node: embedding_1/embedding_lookup = GatherV2[Taxis=DT_INT32, Tindices=DT_INT32, Tparams=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](embedding_1/embeddings/read, embedding_1/Cast, embedding_1/embedding_lookup/axis)]]

我应该怎么做才能解决这些问题?更改我的嵌入层?

最佳答案

将嵌入层的 input_length 设置为您在数据集中看到的最大长度,或者仅使用您在构建模型时使用的相同 maxlenpad_sequences。在这种情况下,任何短于 maxlen 的序列都将被填充,任何长于 maxlen 的序列将被截断。

进一步确保您使用的功能在训练和测试时间相同(即它们的数量不应改变)。

关于python - 预测新结果时检查模型输入 keras 时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51896013/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com