gpt4 book ai didi

python-3.x - 使用预测生成器方法的分类报告

转载 作者:行者123 更新时间:2023-11-30 09:58:50 25 4
gpt4 key购买 nike

所以我正在使用每个字母的数据集制作一个本土语言翻译器。我对机器学习的了解很少,只制作了 2 类图像分类器。最初这些是我的代码,它工作正常,但只能向我显示混淆矩阵,我需要像 F1 分数这样的分类报告,但我似乎无法理解应该如何操作我的代码。

import numpy as np
from sklearn.linear_model import LogisticRegression
from tensorflow import keras, metrics
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from sklearn.metrics import confusion_matrix
import itertools
import matplotlib.pyplot as plt
from webencodings import labels
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

train_path=r'C:\Users\Acer\imagerec\BAYBAYIN\TRAIN'
valid_path=r'C:\Users\Acer\imagerec\BAYBAYIN\VAL'
test_path=r'C:\Users\Acer\imagerec\BAYBAYIN\TEST'

class_labels=['A', 'BA', 'KA', 'GA', 'HA', '1', '2', '3', '4', '5', '6', '7',
'8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19',
'20', '21', '22', '23', '24', '25', '26', '28', '29', '30', '31', '32',
'33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44']

train_batches=ImageDataGenerator(preprocessing_function=keras.applications.xception.preprocess_input)\
.flow_from_directory(train_path, target_size=(299,299),classes=class_labels,batch_size=5)
valid_batches=ImageDataGenerator(preprocessing_function=keras.applications.xception.preprocess_input)\
.flow_from_directory(valid_path, target_size=(299,299),classes=class_labels,batch_size=5)
test_batches=ImageDataGenerator(preprocessing_function=keras.applications.xception.preprocess_input)\
.flow_from_directory(test_path, target_size=(299,299),classes=class_labels,batch_size=5, shuffle=False)

base_model=keras.applications.vgg19.VGG19(include_top=False)

x=base_model.output
x=GlobalAveragePooling2D()(x)
x=Dense(1024, activation='relu')(x)
x=Dense(48, activation='softmax')(x)
model=Model(inputs=base_model.input, outputs=x)


base_model.trainable = False

N=1

print("HANG ON LEARNING IN PROGRESS...")

model.compile(Adam(lr=.0001),loss='categorical_crossentropy', metrics=['accuracy'])
history=model.fit_generator(train_batches, steps_per_epoch=1290, validation_data=valid_batches,
validation_steps=90,epochs=N,verbose=1)

print("[INFO]evaluating model...")

test_labels=test_batches.classes
predictions=model.predict_generator(test_batches, steps=28, verbose=1)


import matplotlib.pyplot as plt
import numpy as np


plt.imshow(np.random.random((48,48)), interpolation='nearest')
plt.xticks(np.arange(0,48), ['A', 'BA', 'KA', 'GA', 'HA', '1', '2', '3', '4', '5', '6', '7',
'8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19',
'20', '21', '22', '23', '24', '25', '26', '28', '29', '30', '31', '32',
'33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44'])
plt.yticks(np.arange(0,48),['A', 'BA', 'KA', 'GA', 'HA', '1', '2', '3', '4', '5', '6', '7',
'8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19',
'20', '21', '22', '23', '24', '25', '26', '28', '29', '30', '31', '32',
'33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44'])



plt.show()
model.save("X19baybayin.h5")

我如何使用预测,或者我可以使用它作为我的 y-pred 以及我应该使用什么作为 y-true

最佳答案

TL;博士

  1. test_batches.classes 为 y_true
  2. np.argmax(预测,axis=-1) as y_pred
<小时/>

我假设每个样本有一个类,因为您使用的是“softmax”和“categorical_crossentropy”,并且您需要为每个样本获取最佳相关(一个)类(多类分类问题)。

澄清一下:

# import classification_report
from sklearn.metrics import classification_report

# get the ground truth of your data.
test_labels=test_batches.classes

# predict the probability distribution of the data
predictions=model.predict_generator(test_batches, steps=28, verbose=1)

# get the class with highest probability for each sample
y_pred = np.argmax(predictions, axis=-1)

# get the classification report
print(classification_report(test_labels, y_pred))

注意:predict_generator 将被弃用,请改用 model.predict。

关于python-3.x - 使用预测生成器方法的分类报告,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59909304/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com