gpt4 book ai didi

python - 如何从 Keras 的 model.predict 函数获取预测标签?

转载 作者:行者123 更新时间:2023-12-01 09:11:37 24 4
gpt4 key购买 nike

我使用 Keras 库构建了一个 LSTM 模型来预测 Quora 官方数据集上的重复问题。测试标签为 0 或 1。1 表示问题对重复。使用 model.fit 构建模型后,我使用 model.predict 对测试数据测试模型。输出是一个值(概率)数组,如下所示:

 [ 0.00514298]
[ 0.15161049]
[ 0.27588326]
[ 0.00236167]
[ 1.80067325]
[ 0.01048524]
[ 1.43425131]
[ 1.99202418]
[ 0.54853892]
[ 0.02514757]

我只显示数组中的前 10 个值。我不明白这些值的含义是什么,以及如何将其与测试标签进行比较以计算测试准确性。我希望模型将二进制预测值输出为 0 或 1,而不是概率。请引用下面我的代码的最后一部分:

sequence_1_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences_1 = embedding_layer(sequence_1_input)
x1 = lstm_layer(embedded_sequences_1)

sequence_2_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences_2 = embedding_layer(sequence_2_input)
y1 = lstm_layer(embedded_sequences_2)

merged = concatenate([x1, y1])
merged = Dropout(rate_drop_dense)(merged)
merged = BatchNormalization()(merged)

merged = Dense(num_dense, activation=act)(merged)
merged = Dropout(rate_drop_dense)(merged)
merged = BatchNormalization()(merged)

preds = Dense(1, activation='sigmoid')(merged)


########################################
## train the model
########################################
model = Model(inputs=[sequence_1_input, sequence_2_input], \
outputs=preds)
model.compile(loss='binary_crossentropy',
optimizer='nadam',
metrics=['acc'])



hist = model.fit([data_1_train, data_2_train], labels_train, \
validation_data=([data_1_val, data_2_val], labels_val, weight_val), \
epochs=200, batch_size=2048, shuffle=True, \
class_weight=class_weight, callbacks=[early_stopping, model_checkpoint])


preds = model.predict([test_data_1, test_data_2], batch_size=8192,
verbose=1)
preds += model.predict([test_data_2, test_data_1], batch_size=8192,
verbose=1)
preds /= 2
print(type(preds))
print(preds[:20])
print('preds.ravel')
print(preds.ravel())

最佳答案

正如你所说,你的输出是一个带有概率的np数组。您可以通过执行例如 (model.predict(X) > 0.5).astype(int)

将其转换为二进制标签

关于python - 如何从 Keras 的 model.predict 函数获取预测标签?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51618745/

24 4 0